Vision Transformers on CIFAR-10 dataset: Part 3
This article is divided into three parts:
Part(1/3): Introduction and Installation of Libraries
Part(2/3): Data Preparation
Part(3/3): Fine-tuning of the model
In the previous articles, we briefly introduced Vision Transformers and PyTorch, installed all the necessary libraries, and prepared the data for training. Now let's fine-tune the model and look at the results.
Model Train function
This function runs one training epoch over the train dataloader and returns the average loss together with the model predictions.
# function to train the model for one epoch
def train():
    model.train()
    total_loss = 0
    # empty list to save model predictions
    total_preds = []
    # iterate over batches
    for step, batch in enumerate(train_dataloader):
        # progress update after every 50 batches
        if step % 50 == 0 and step != 0:
            print(' Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))
        # this unpacking assumes each batch is a dict with exactly two entries,
        # the labels first and the pixel values second; push both to the device
        lbl, pix = batch.items()
        lbl, pix = lbl[1].to(device), pix[1].to(device)
        # get model predictions for the current batch
        preds = model(pix)
        # compute the loss between actual and predicted values
        loss = cross_entropy(preds, lbl)
        # add on to the total loss
        total_loss = total_loss + loss.item()
        # backward pass to calculate the gradients
        loss.backward()
        # clip the gradient norm to 1.0; this helps prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # update parameters
        optimizer.step()
        # clear calculated gradients
        optimizer.zero_grad()
        # move the predictions to the CPU as a NumPy array
        preds = preds.detach().cpu().numpy()
        # append the model predictions
        total_preds.append(preds)
    # compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)
    # join the per-batch predictions into a single array
    total_preds = np.concatenate(total_preds, axis=0)
    # return the loss and predictions
    return avg_loss, total_preds
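To see what clipping the gradient norm to 1.0 actually does, here is a minimal pure-Python sketch of the idea. It is not PyTorch's implementation; the helper name `clip_by_global_norm` and the gradient values are made up for illustration. The point is that one norm is computed over all gradients together, and every gradient is scaled by the same factor only when that norm exceeds the limit.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is at most max_norm."""
    # One L2 norm over every gradient value, across all parameters.
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Scaling factor, capped at 1.0 so small gradients are left unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

# Hypothetical gradients for two (flattened) parameter tensors.
grads = [[3.0, 4.0], [0.0, 12.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(f"norm before: {norm_before:.1f}, norm after: {norm_after:.3f}")
```

Because all gradients share one scaling factor, the direction of the update is preserved and only its size is limited.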
Model Eval function
This function evaluates the model on the validation dataloader and returns the average validation loss.
def eval():
    total_loss = 0
    model.eval()  # prep model for evaluation
    # gradients are not needed for evaluation, so disable them to save memory
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            lbl, pix = batch.items()
            lbl, pix = lbl[1].to(device), pix[1].to(device)
            # forward pass: compute predicted outputs by passing inputs to the model
            preds = model(pix)
            # calculate the loss
            loss = cross_entropy(preds, lbl)
            total_loss += loss.item()
    # average validation loss over all batches
    return total_loss / len(val_dataloader)
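Both functions unpack the batch with `batch.items()`, which only works when the batch dictionary holds exactly two entries in the expected order. The small sketch below uses plain lists in place of tensors and hypothetical key names (`label`, `pixel_values`, which may differ from the ones produced in Part 2) to show how positional unpacking can silently swap labels and images, and how access by key avoids that.

```python
# Stand-in batch: plain lists instead of tensors; the key names are hypothetical.
batch = {"label": [3, 7], "pixel_values": [[0.1, 0.2], [0.3, 0.4]]}

# Positional unpacking: each item is a (key, value) pair.
lbl, pix = batch.items()
labels_by_position, pixels_by_position = lbl[1], pix[1]

# If the keys arrive in the other order, the same code swaps the two silently.
reordered = {"pixel_values": batch["pixel_values"], "label": batch["label"]}
lbl, pix = reordered.items()
swapped = lbl[1] == batch["pixel_values"]

# Access by key does not depend on the ordering.
labels_by_key = reordered["label"]
pixels_by_key = reordered["pixel_values"]
print(swapped, labels_by_key == labels_by_position)
```

If you know the key names of your dataset, indexing the batch by key is the safer choice.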
Training the model
# make cuDNN deterministic so that runs are more reproducible
# (use this together with fixed random seeds)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

min_loss = float('inf')  # best validation loss seen so far
es = 0  # number of consecutive epochs without improvement
for epoch in range(epochs):
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    # train for one epoch, then evaluate on the validation set
    train_loss, _ = train()
    val_loss = eval()
    # Early Stopping
    if val_loss < min_loss:
        min_loss = val_loss
        es = 0
    else:
        es += 1
        if es > 4:
            print("Early stopping with train_loss: ", train_loss, "and val_loss for this epoch: ", val_loss, "...")
            break
    print(f'\n Training Loss: {train_loss:.3f}')
    print(f'\n Validation Loss: {val_loss:.3f}')
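The early-stopping rule in the loop can be easier to follow in isolation. Here is a small sketch that wraps the same logic in a helper class (the name `EarlyStopper` and the loss values are made up for illustration) and runs it on simulated validation losses, so no model is needed.

```python
class EarlyStopper:
    """Track the best validation loss and signal when training should stop."""

    def __init__(self, patience=4):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True once patience is exceeded."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs > self.patience

# Simulated validation losses: they improve for three epochs, then get worse.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.70, 0.71, 0.72]
stopper = EarlyStopper(patience=4)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print("stopped at epoch", stopped_at, "with best loss", stopper.best_loss)
```

With a patience of 4, training stops at the fifth consecutive epoch without improvement, which matches the `es > 4` check in the loop.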
Save the model weights
torch.save(model.state_dict(), '/working/model')
Testing the model on Test DataLoader
# import tqdm here if it was not already imported in the earlier parts
from tqdm import tqdm

# named test() so that it does not overwrite the validation eval() above
def test():
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            lbl, pix = batch.items()
            lbl, pix = lbl[1].to(device), pix[1].to(device)
            outputs = model(pix)
            # predicted class = index of the largest output score
            outputs = torch.argmax(outputs, dim=1)
            y_pred.extend(outputs.cpu().numpy())
            y_true.extend(lbl.cpu().numpy())
    return y_pred, y_true

y_pred, y_true = test()
Calculate the accuracy
correct = np.array(y_pred) == np.array(y_true)
accuracy = correct.sum() / len(correct)
print("Accuracy of the model", accuracy)
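Overall accuracy can hide classes that the model handles poorly, so it can help to break the number down per class. Below is a minimal sketch using only the standard library; the helper `per_class_accuracy` and the demo labels are made up for illustration and are not results from this model.

```python
from collections import defaultdict

def per_class_accuracy(y_true, y_pred):
    """Return {class: fraction of that class predicted correctly}."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return {c: correct[c] / total[c] for c in sorted(total)}

# Made-up labels and predictions for three classes.
y_true_demo = [0, 0, 1, 1, 1, 2]
y_pred_demo = [0, 1, 1, 1, 0, 2]

overall = sum(t == p for t, p in zip(y_true_demo, y_pred_demo)) / len(y_true_demo)
by_class = per_class_accuracy(y_true_demo, y_pred_demo)
print(f"overall: {overall:.3f}")
for cls, acc in by_class.items():
    print(f"class {cls}: {acc:.3f}")
```

The same helper should work on the `y_true` and `y_pred` lists returned above, after converting their elements to plain integers.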
To achieve better results:
1. We can use the entire dataset for training.
2. We can increase the number of epochs.
3. We can consider hyperparameter tuning with Optuna.