Vision Transformers: Are They a Worthy Successor to CNNs?

Sander Ali Khowaja
6 min read · Aug 18, 2023


“An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale” is a paper published by Google’s research team that inspired this blog post. The paper proposes applying a pure Transformer directly to sequences of image patches. When the Vision Transformer (ViT) is pre-trained on a large amount of data, it matches or outperforms state-of-the-art convolutional networks on multiple benchmarks while requiring substantially fewer computational resources to train.

Transformers are now the go-to model for NLP because they are computationally efficient and scale well to large models and datasets. CNNs are still the go-to for computer vision, although several researchers have tried combining them with self-attention. The authors found that a standard Transformer trained on a medium-sized dataset such as ImageNet ends up a few percentage points below a ResNet-like architecture of comparable size. But when they pre-trained on much bigger datasets, the ViT got great results and matched or beat the state of the art on a bunch of image recognition benchmarks.

Figure 1 (from the original paper) shows the model. A 2D image is split into fixed-size patches and each patch is flattened, so the image becomes a sequence of flattened 2D patches. The patches are then mapped to a fixed latent vector size using a trainable linear projection. A learnable embedding (similar to BERT’s [class] token) is prepended to the sequence of patch embeddings, and its state at the output of the Transformer encoder serves as the image representation. A classification head is attached to this representation during both pre-training and fine-tuning. Position embeddings are added to the patch embeddings to retain positional information, and the resulting sequence of embedding vectors is the input to the Transformer encoder, which is made up of alternating layers of multi-head self-attention and MLP blocks.
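The "split into patches and flatten" step is easy to see on a toy example. The snippet below is only a minimal pure-Python sketch of that one step, not the paper's or the repository's implementation; the function name patchify and the nested-list image format are assumptions made for illustration. The linear projection, class token, and position embeddings are left out.

```python
def patchify(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened patches.

    Returns a list of N = (H / P) * (W / P) vectors, each of length P * P * C,
    ordered row by row starting from the top-left patch.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")

    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            flat = []
            for row in range(top, top + patch_size):
                for col in range(left, left + patch_size):
                    flat.extend(image[row][col])  # append all C channel values
            patches.append(flat)
    return patches


# Toy 4 x 4 single-channel image whose pixel value equals its raster index.
toy_image = [[[r * 4 + c] for c in range(4)] for r in range(4)]
toy_patches = patchify(toy_image, patch_size=2)
print(len(toy_patches), toy_patches[0])  # 4 [0, 1, 4, 5]
```

For a 224x224 RGB image with 16x16 patches (the setting in the paper's title), the same function yields 196 vectors of length 768.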

CNNs have been around for a while and are great for image processing tasks. They capture local spatial patterns and use stacked convolutional layers to extract hierarchical features. They can learn from huge amounts of image data and are really good at things like image classification, object detection, and segmentation. But while CNNs have a strong track record in computer vision and can handle big data, Vision Transformers are better suited to situations where global dependencies and context are important. The catch is that Vision Transformers usually need a lot more training data to reach the same performance as a CNN. Plus, CNNs tend to be more efficient, which makes them easier to use in real-time settings and with limited resources.

Vision Transformers on Cat vs Dog Dataset (Demo)

In this section, we will train an image classifier on the Kaggle cats-vs-dogs dataset, using both a CNN and a Vision Transformer. To begin, download the dataset of 25,000 RGB images from Kaggle. If you have not already done so, set up your Kaggle API credentials first (the Kaggle API documentation explains how). The Python code below will download the dataset to your current working directory.

from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

api.dataset_download_files('/kaggle-cat-vs-dog-dataset', path='./')
#'./' is for the current working directory
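If api.authenticate() fails, missing credentials are the usual cause. The Kaggle API looks for a kaggle.json token file in the .kaggle folder of your home directory, or for the KAGGLE_USERNAME and KAGGLE_KEY environment variables. The check below is a small sketch under that assumption; the helper name kaggle_credentials_present is made up for this post, and it does not cover a custom config directory.

```python
import os
from pathlib import Path


def kaggle_credentials_present(home=None, env=None):
    """Return True if Kaggle API credentials appear to be configured."""
    home = Path(home) if home is not None else Path.home()
    env = os.environ if env is None else env

    # Option 1: both environment variables are set.
    if env.get("KAGGLE_USERNAME") and env.get("KAGGLE_KEY"):
        return True
    # Option 2: the token file downloaded from the Kaggle account page exists.
    return (home / ".kaggle" / "kaggle.json").is_file()


if not kaggle_credentials_present():
    print("No Kaggle credentials found; api.authenticate() will likely fail.")
```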

You can unzip the downloaded files using the following commands:

!unzip -qq kaggle-cat-vs-dog-dataset.zip
!rm -r kaggle-cat-vs-dog-dataset.zip
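The ! prefix only works inside a notebook (Jupyter or Colab). If you are running a plain Python script instead, a rough standard-library equivalent could look like the sketch below. The helper name extract_and_remove is hypothetical and not part of the repository.

```python
import zipfile
from pathlib import Path


def extract_and_remove(zip_path, dest="."):
    """Extract every member of zip_path into dest, then delete the archive."""
    zip_path = Path(zip_path)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    zip_path.unlink()  # mirrors the rm step of the shell version


# Usage, assuming the archive name from the shell commands above:
# extract_and_remove("kaggle-cat-vs-dog-dataset.zip")
```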

Clone the following GitHub repository and access the code within the utils directory.

!git clone https://github.com/sander-ali/ViT_vs_CNN.git
!mv ViT_vs_CNN/utils .

Use the following code to clean the downloaded dataset and prepare it for the subsequent training steps. The cleaning and loading code is given below; the dataset class follows PyTorch’s Dataset interface so that it can be wrapped in a DataLoader.

import torch.nn as nn
import torch
import torch.optim as optim

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from sklearn.model_selection import train_test_split

import os
import shutil  # needed for shutil.rmtree in delete_non_jpeg_files


class LoadData:
    def __init__(self):
        self.cat_path = 'kagglecatsanddogs_3367a/PetImages/Cat'
        self.dog_path = 'kagglecatsanddogs_3367a/PetImages/Dog'

    def delete_non_jpeg_files(self, directory):
        # Remove every entry that does not end in .jpg or .jpeg
        for filename in os.listdir(directory):
            if not filename.endswith('.jpg') and not filename.endswith('.jpeg'):
                file_path = os.path.join(directory, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                    print('deleted', file_path)
                except Exception as e:
                    print('Failed to delete %s. Reason: %s' % (file_path, e))

    def data(self):
        self.delete_non_jpeg_files(self.dog_path)
        self.delete_non_jpeg_files(self.cat_path)

        # Build (path, label) pairs: dog = 1, cat = 0
        dog_list = os.listdir(self.dog_path)
        dog_list = [(os.path.join(self.dog_path, i), 1) for i in dog_list]

        cat_list = os.listdir(self.cat_path)
        cat_list = [(os.path.join(self.cat_path, i), 0) for i in cat_list]

        total_list = cat_list + dog_list

        # Hold out 20% for testing, then 20% of the remainder for validation
        train_list, test_list = train_test_split(total_list, test_size=0.2)
        train_list, val_list = train_test_split(train_list, test_size=0.2)
        print('train list', len(train_list))
        print('test list', len(test_list))
        print('val list', len(val_list))
        return train_list, test_list, val_list


# Data augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


class dataset(torch.utils.data.Dataset):

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    # dataset length
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    # load a single image and its label
    def __getitem__(self, idx):
        img_path, label = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')
        # Apply the transform only if one was provided
        img_transformed = self.transform(img) if self.transform is not None else img
        return img_transformed, label
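Note that the two chained train_test_split calls in LoadData.data() do not give an 80/20 split overall: the second call takes 20% of the remaining 80%, so you end up with roughly 64% train, 16% validation, and 20% test. The arithmetic can be sketched as follows, assuming scikit-learn rounds the held-out part up. The helper split_sizes is illustrative only, and the real counts will be a little lower than 25,000 if the cleaning step deletes files.

```python
import math


def split_sizes(n_samples, test_size=0.2, val_size=0.2):
    """Sizes produced by two chained train_test_split calls."""
    n_test = math.ceil(n_samples * test_size)
    n_remaining = n_samples - n_test
    n_val = math.ceil(n_remaining * val_size)
    n_train = n_remaining - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = split_sizes(25_000)
print(n_train, n_val, n_test)  # 16000 4000 5000
```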

CNN Model

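This image classifier’s CNN model is composed of three convolutional blocks. Each block applies a convolution with a kernel size of 3 and a stride of 2, followed by batch normalization, a ReLU activation, and max pooling with a window of 2. These blocks are followed by two fully connected layers: the first has 10 nodes and the second outputs the 2 class scores. The code snippet below illustrates this structure.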

class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.dropout = nn.Dropout(0.5)  # defined here but not applied in forward()
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, 3 * 3 * 64)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out
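If you are wondering where the 3 * 3 * 64 input size of fc1 comes from, it follows from pushing a 224x224 image through the three blocks. Here is a quick check using the standard output-size formulas for Conv2d and MaxPool2d, written in plain Python so it runs without PyTorch; conv_out and pool_out are helper names invented for this check.

```python
def conv_out(size, kernel=3, stride=2, padding=0):
    """Output side length of a square convolution (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Output side length of MaxPool2d(kernel), whose stride defaults to kernel."""
    return size // kernel


side = 224
for block in range(1, 4):
    side = pool_out(conv_out(side))
    print(f"after layer{block}: {side} x {side}")
# after layer1: 55 x 55
# after layer2: 13 x 13
# after layer3: 3 x 3

flattened = side * side * 64  # 64 channels come out of layer3
print(flattened)  # 576, i.e. the 3 * 3 * 64 expected by fc1
```

If you change the input resolution in the transform, this number changes too, and fc1 has to be adjusted accordingly.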

The CNN was trained for 10 epochs on an NVIDIA RTX 3060 Ti GPU. For further details, see the GitHub repository, which contains the code and the training loop. The results for the CNN approach are shown below:

To train the ViT, first build the model with the following code:

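import torch
from utils.simple_vit import ViT

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ViT(
    image_size=224,
    patch_size=32,
    num_classes=2,
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)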
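With these settings the encoder works on a fairly short sequence. The back-of-the-envelope calculation below shows how many tokens a 224x224 image turns into and how large each flattened patch is before it is projected down to dim=128. Treat it as a sketch: vit_sequence_shape is a hypothetical helper, and the extra class token is an assumption about utils.simple_vit that follows the original paper.

```python
def vit_sequence_shape(image_size, patch_size, channels=3, use_class_token=True):
    """Return (sequence_length, flattened_patch_dim) for a square image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    patch_dim = channels * patch_size ** 2
    # Assumption: one learnable class token is prepended, as in the paper.
    seq_len = num_patches + (1 if use_class_token else 0)
    return seq_len, patch_dim


seq_len, patch_dim = vit_sequence_shape(image_size=224, patch_size=32)
print(seq_len, patch_dim)  # 50 3072
```

So each image becomes 7 x 7 = 49 patches (plus one class token), and every patch of 3 * 32 * 32 = 3072 raw values is mapped to a 128-dimensional embedding.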

The following parameters are essential for the vision transformer:

  • image_size = 224;
  • patch_size = 32;
  • num_classes = 2; and
  • dim = 128.

The image_size parameter specifies the width and height of the images fed into the model, here 224x224 pixels. The patch_size parameter specifies the side length of each square patch that the image is divided into. num_classes is the number of output classes (2 here, cat and dog), and dim is the dimensionality of the embedding used to represent each image patch. Similarly,

  • depth = 12;
  • heads = 8;
  • mlp_dim = 1024;
  • dropout = 0.1; and
  • emb_dropout = 0.1.

The depth parameter specifies the number of layers in the Transformer encoder, and heads specifies the number of attention heads in each layer’s self-attention mechanism. mlp_dim specifies the dimensionality of the hidden layer in the MLP blocks, which transform the token representations after self-attention. dropout controls the dropout rate, a regularization technique used to reduce overfitting by setting a random fraction of input units to 0 during training. emb_dropout specifies the dropout rate applied specifically to the token embeddings, which helps to prevent over-reliance on specific tokens during training.
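To make the dropout description concrete, here is a tiny pure-Python sketch of "inverted" dropout, the variant PyTorch’s nn.Dropout uses: during training each value is zeroed with probability p and the survivors are scaled by 1 / (1 - p), while at inference time the input passes through unchanged. The function below is for illustration only and is not the repository’s code.

```python
import random


def dropout(values, p=0.1, training=True, rng=None):
    """Inverted dropout on a list of floats."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    if rng is None:
        rng = random.Random()
    scale = 1.0 / (1.0 - p)  # keeps the expected value of each unit unchanged
    return [0.0 if rng.random() < p else v * scale for v in values]


activations = [1.0] * 10
print(dropout(activations, p=0.1, rng=random.Random(0)))  # some zeros, rest scaled up
print(dropout(activations, training=False))  # unchanged at inference time
```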

For the classification task, the Vision Transformer was trained for 20 epochs on the same NVIDIA RTX 3060 Ti GPU. Because its training loss converged slowly, it was given 20 epochs rather than the 10 used for the CNN. Even so, the CNN achieved better results: around 76% accuracy in 10 epochs, compared with up to 71% accuracy for the ViT in 20 epochs. This is in line with the earlier point that Vision Transformers usually need a lot more training data than a CNN to reach the same performance.

In conclusion, comparing CNNs and Vision Transformers reveals significant differences in model size and memory requirements as well as in accuracy and performance. CNN models are generally compact and memory-efficient, which suits resource-limited environments, and they have proven highly accurate across a variety of computer vision tasks. Vision Transformers, however, offer a more powerful approach to capturing global dependencies and context-based understanding of images, which can lead to improved performance in some tasks.

Vision Transformers typically have a larger model size and a higher memory requirement than CNNs. While they can achieve high accuracy, particularly with large datasets, their computational demands may limit their usability in resource-limited scenarios. Consequently, the choice between a CNN and a Vision Transformer depends on the needs and constraints of the task, taking into account resource availability, dataset size, and the trade-off between model complexity and performance. Both architectures are expected to keep improving in the near future, which should help researchers and practitioners make more informed decisions.

Written by Sander Ali Khowaja

An aspiring academician and researcher interested in Computer Vision, Privacy Preservation Machine Learning, Self-Supervised Federated Learning & Data Analytics
