Vision Transformers on CIFAR-10 dataset: Part 3

This article is divided into three parts:

AI Brewery · 3 min read · Dec 19, 2021

Part(1/3): Introduction and Installation of Libraries

Part(2/3): Data Preparation

Part(3/3): Fine-tuning of the model

In the previous articles, we briefly introduced the concepts behind Vision Transformers and PyTorch, installed the necessary libraries, and prepared the data for training. Now let's fine-tune the model and look at the results.

Model Train function

This function trains the model for one epoch on the training DataLoader.

# function to train the model for one epoch
def train():
    model.train()
    total_loss = 0
    # empty list to save model predictions
    total_preds = []

    # iterate over batches
    for step, batch in enumerate(train_dataloader):
        # progress update after every 50 batches
        if step % 50 == 0 and step != 0:
            print('  Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))

        # the dataloader yields a dict of (name, tensor) pairs; unpack the two
        # entries (labels first, pixel values second) and push them to the GPU
        (_, lbl), (_, pix) = batch.items()
        lbl, pix = lbl.to(device), pix.to(device)

        # get model predictions for the current batch
        preds = model(pix)

        # compute the loss between actual and predicted values
        loss = cross_entropy(preds, lbl)

        # add on to the total loss
        total_loss = total_loss + loss.item()

        # backward pass to calculate the gradients
        loss.backward()

        # clip the gradients to 1.0 to prevent the exploding gradient problem
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # update parameters
        optimizer.step()

        # clear calculated gradients
        optimizer.zero_grad()

        # append the model predictions
        preds = preds.detach().cpu().numpy()
        total_preds.append(preds)

    # compute the average training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)

    # return the loss and predictions
    return avg_loss, total_preds
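
train() and the functions that follow rely on a few objects created in the earlier parts of this series: model, train_dataloader, val_dataloader, test_dataloader, device, optimizer, cross_entropy and epochs. If you are running only this part, a minimal setup could look like the sketch below; the AdamW optimizer, learning rate and epoch count here are illustrative assumptions, not the exact values from the earlier parts.

import torch
import numpy as np
from torch.nn.functional import cross_entropy
from torch.optim import AdamW

# run on GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # `model` is the ViT built in the earlier parts

# illustrative choices, not the exact values from the earlier parts
optimizer = AdamW(model.parameters(), lr=1e-5)
epochs = 10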

Model Eval function

This function evaluates the model on the validation DataLoader (it is named evaluate so it does not shadow Python's built-in eval).

def evaluate():
    model.eval()  # prep model for evaluation
    total_loss = 0

    # disable gradient tracking during validation to save memory and compute
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            # unpack the batch and push the tensors to the GPU
            (_, lbl), (_, pix) = batch.items()
            lbl, pix = lbl.to(device), pix.to(device)

            # forward pass: compute predicted outputs by passing inputs to the model
            preds = model(pix)

            # calculate the loss
            loss = cross_entropy(preds, lbl)
            total_loss += loss.item()

    return total_loss / len(val_dataloader)

Training the model

from math import inf

# make the experiment reproducible; set these once before training,
# together with the usual random seeds
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

min_loss = inf
es = 0  # early-stopping counter
for epoch in range(epochs):
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

    # train and validate
    train_loss, _ = train()
    val_loss = evaluate()

    # early stopping: reset the counter on improvement,
    # otherwise stop after 5 epochs without improvement
    if val_loss < min_loss:
        min_loss = val_loss
        es = 0
    else:
        es += 1
        if es > 4:
            print("Early stopping with train_loss:", train_loss,
                  "and val_loss for this epoch:", val_loss, "...")
            break

    print(f'\n Training Loss: {train_loss:.3f}')
    print(f'\n Validation Loss: {val_loss:.3f}')

Save the model weights

torch.save(model.state_dict(), '/working/model')
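
The saved weights can later be loaded back into the same architecture for inference; a minimal sketch, assuming the same model object and save path as above:

# load the saved weights into the same architecture
model.load_state_dict(torch.load('/working/model', map_location=device))
model.to(device)
model.eval()  # switch to evaluation mode before inference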

Testing the model on Test DataLoader

from tqdm import tqdm

def test():
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            # unpack the batch and push the tensors to the GPU
            (_, lbl), (_, pix) = batch.items()
            lbl, pix = lbl.to(device), pix.to(device)

            # forward pass, then take the class with the highest logit
            outputs = model(pix)
            outputs = torch.argmax(outputs, axis=1)

            y_pred.extend(outputs.cpu().numpy())
            y_true.extend(lbl.cpu().numpy())

    return y_pred, y_true

y_pred, y_true = test()

Calculate the accuracy

correct = np.array(y_pred) == np.array(y_true)
accuracy = correct.sum() / len(correct)
print(f"Accuracy of the model: {accuracy:.4f}")

To achieve better results, we can:
1. Train on the entire dataset instead of a subset.
2. Increase the number of epochs.
3. Tune the hyperparameters with Optuna (see the sketch below).
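
For the third point, a minimal Optuna sketch: the objective re-creates the optimizer with a trial-suggested learning rate, runs a short training loop, and returns the validation loss. The search space, trial count and epochs-per-trial are illustrative assumptions.

import optuna
from torch.optim import AdamW

def objective(trial):
    # illustrative search space; in practice the model weights should also be
    # re-initialised at the start of every trial for a fair comparison
    lr = trial.suggest_float('lr', 1e-6, 1e-4, log=True)

    global optimizer
    optimizer = AdamW(model.parameters(), lr=lr)

    # short training run per trial, scored by validation loss
    for _ in range(3):
        train()
    return evaluate()

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)
print('Best hyperparameters:', study.best_params)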
