LeNet
Libraries
We begin by importing the necessary Python modules and libraries for building and training our neural network with the MNIST dataset.
# Standard libraries
from typing import Any
# Third-party packages
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
from tqdm import tqdm
Functions
We define a small helper function that displays a row of images alongside their labels.
def show_images(images, labels):
    """Display a row of grayscale images with their corresponding labels."""
    fig, axes = plt.subplots(1, len(images), figsize=(10, 2))
    for img, label, ax in zip(images, labels, axes):
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f'Label: {label}')
        ax.axis('off')
    plt.show()
Main
We start by loading the MNIST dataset, including both the training and test sets. While loading the data, we apply two transformations. First, each image is converted into a PyTorch tensor with pixel values scaled to [0, 1], which allows the model to process the data efficiently. Second, we normalize the images using transforms.Normalize((0.1307,), (0.3081,)), where 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training set's pixel values. Normalization gives the data roughly zero mean and unit variance, which helps the model train more effectively and converge faster.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)
train_dataset
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
test_dataset
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
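As a sanity check, we can recompute the normalization statistics ourselves. The sketch below reads the raw uint8 pixels straight from the dataset object (bypassing the transform) and scales them to [0, 1] the way ToTensor() does:
# Recompute the MNIST normalization constants from the raw training images.
pixels = train_dataset.data.float() / 255.0
print(pixels.mean().item(), pixels.std().item())  # ≈ 0.1307, 0.3081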
We can use the DataLoader class to divide the dataset into batches and shuffle the data efficiently with PyTorch's built-in functionality. To get started, we first define a few constants that will be used throughout data loading and training.
BATCH_SIZE: int = 32
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # no need to shuffle for evaluation
)
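With these settings, the training loader yields 60000 / 32 = 1875 batches per epoch, and the test loader ceil(10000 / 32) = 313 (its last batch is smaller). A quick check:
# Number of batches per epoch for each loader.
print(len(train_dataloader), len(test_dataloader))  # 1875, 313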
We can visualize some examples of sample-label pairs from the dataset. If we take one batch from the DataLoader, we will get as many samples as the chosen batch size. Each MNIST sample is a grayscale image of size 28 × 28, meaning it has a single channel. By inspecting these batches, we can better understand the structure and format of the data before feeding it into a neural network.
data_iter = iter(train_dataloader)
train_images, train_labels = next(data_iter)
train_images.shape, train_labels.shape
(torch.Size([32, 1, 28, 28]), torch.Size([32]))
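Note that, after normalization, pixel values no longer lie in [0, 1]: black background pixels map to (0 - 0.1307) / 0.3081 ≈ -0.42 and pure white pixels to (1 - 0.1307) / 0.3081 ≈ 2.82, which we can verify on this batch:
# Value range after Normalize((0.1307,), (0.3081,)).
print(train_images.min().item(), train_images.max().item())  # ≈ -0.42, ≈ 2.82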
We will display the first 10 samples from the dataset. This allows us to quickly inspect the images and their corresponding labels to ensure that the data has been loaded and preprocessed correctly.
show_images(train_images[:10], train_labels[:10])
Next, we will create a convolutional neural network (CNN) model, inspired by Yann LeCun’s LeNet architecture, and adapt it for the MNIST dataset. This model will use convolutional layers to automatically extract features from the images, followed by fully connected layers to perform classification.
class LeNet(nn.Module):
    def __init__(self, input_tensor_shape: tuple[int, ...], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.input_tensor_shape = input_tensor_shape
        self.model = nn.Sequential(
            # 1x28x28 -> 16x13x13 ("valid" means no padding)
            nn.Conv2d(in_channels=self.input_tensor_shape[0], out_channels=16,
                      kernel_size=4, stride=2, padding="valid"),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            # 16x13x13 -> 32x5x5
            nn.Conv2d(in_channels=16, out_channels=32,
                      kernel_size=4, stride=2, padding="valid"),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            # Average each of the 32 feature maps down to a single value
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            # 32 features -> 10 class logits
            nn.Linear(32, 10),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.model(input_tensor)
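With padding="valid" (no padding) and stride 2, each convolution maps a spatial side of n to floor((n - 4) / 2) + 1, so the feature maps shrink from 28 x 28 to 13 x 13 and then 5 x 5 before the adaptive pooling collapses each of the 32 maps to a single value. As a quick smoke test, we can push a dummy batch through a freshly built model and confirm the output shape:
# Dummy forward pass: output should be one logit per class, i.e. (batch, 10).
dummy = torch.randn(2, 1, 28, 28)
print(LeNet(input_tensor_shape=(1, 28, 28))(dummy).shape)  # torch.Size([2, 10])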
Next, we instantiate the model and inspect its architecture with torchinfo's summary. For training, we use AdamW as the optimizer and cross-entropy as the loss function. AdamW is an adaptive optimizer that combines the benefits of Adam with decoupled weight decay, helping the model converge efficiently. Cross-entropy loss is well suited for multi-class classification tasks like MNIST, as it measures the difference between the predicted class probabilities and the true labels.
model = LeNet(input_tensor_shape=(1, 28, 28))
summary(model, input_size=(BATCH_SIZE, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
LeNet                                    [32, 10]                  --
├─Sequential: 1-1                        [32, 10]                  --
│    └─Conv2d: 2-1                       [32, 16, 13, 13]          272
│    └─BatchNorm2d: 2-2                  [32, 16, 13, 13]          32
│    └─ReLU: 2-3                         [32, 16, 13, 13]          --
│    └─Conv2d: 2-4                       [32, 32, 5, 5]            8,224
│    └─BatchNorm2d: 2-5                  [32, 32, 5, 5]            64
│    └─ReLU: 2-6                         [32, 32, 5, 5]            --
│    └─AdaptiveAvgPool2d: 2-7            [32, 32, 1, 1]            --
│    └─Flatten: 2-8                      [32, 32]                  --
│    └─Linear: 2-9                       [32, 10]                  330
==========================================================================================
Total params: 8,922
Trainable params: 8,922
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 8.06
==========================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 1.80
Params size (MB): 0.04
Estimated Total Size (MB): 1.93
==========================================================================================
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_function = torch.nn.CrossEntropyLoss()
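One detail worth emphasizing: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, which is exactly why the model ends in a plain Linear layer with no softmax. A toy three-class example:
# CrossEntropyLoss = log-softmax + negative log-likelihood in one step.
logits = torch.tensor([[2.0, 0.5, -1.0]])  # unnormalized scores for 3 classes
target = torch.tensor([0])                 # index of the true class
print(loss_function(logits, target))       # tensor(0.2414)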
Now, we need to create the training loop. This loop will iterate over the dataset for a number of epochs, feeding batches of data through the model, computing the loss, performing backpropagation, and updating the model’s parameters. A well-structured training loop is essential for effectively training the network and monitoring its performance over time.
NUM_EPOCHS: int = 5

train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    train_loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", leave=False)
    for batch_image, batch_label in train_loop:
        optimizer.zero_grad()
        outputs = model(batch_image)
        loss = loss_function(outputs, batch_label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += batch_label.size(0)
        correct += (predicted == batch_label).sum().item()

    train_losses.append(running_loss / len(train_dataloader))
    train_accuracies.append(100 * correct / total)

    # Evaluation phase
    model.eval()
    test_loss, correct_test, total_test = 0.0, 0, 0
    test_loop = tqdm(test_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Test]", leave=False)
    with torch.no_grad():
        for images, labels in test_loop:
            outputs = model(images)
            loss = loss_function(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_losses.append(test_loss / len(test_dataloader))
    test_accuracies.append(100 * correct_test / total_test)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
          f"Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.2f}% "
          f"| Test Loss: {test_losses[-1]:.4f}, Test Acc: {test_accuracies[-1]:.2f}%")
Epoch [1/5] Train Loss: 0.9334, Train Acc: 76.77% | Test Loss: 0.4189, Test Acc: 90.49%
Epoch [2/5] Train Loss: 0.3321, Train Acc: 92.07% | Test Loss: 0.2243, Test Acc: 94.45%
Epoch [3/5] Train Loss: 0.2210, Train Acc: 94.38% | Test Loss: 0.1833, Test Acc: 95.19%
Epoch [4/5] Train Loss: 0.1745, Train Acc: 95.46% | Test Loss: 0.1389, Test Acc: 96.30%
Epoch [5/5] Train Loss: 0.1469, Train Acc: 96.10% | Test Loss: 0.1318, Test Acc: 96.39%
After training, we plot the loss and accuracy curves for the training and test sets across epochs:
epochs = range(1, NUM_EPOCHS + 1)

plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, test_losses, label="Test Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training vs Test Loss")
plt.legend()
plt.show()

plt.plot(epochs, train_accuracies, label="Train Accuracy")
plt.plot(epochs, test_accuracies, label="Test Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy (%)")
plt.title("Training vs Test Accuracy")
plt.legend()
plt.show()
Now, we can visualize the data in a lower-dimensional space using t-SNE. This technique projects high-dimensional representations (here, the 10-dimensional logits the model produces for each training image) into two dimensions, making it easier to observe clusters and separations between classes. Visualizing the data this way gives insight into how well the model has learned to distinguish between digits.
all_labels = []
embeddings = []

model.eval()
with torch.no_grad():
    for batch_image, batch_label in train_dataloader:
        # Use the model's 10-dimensional logits as the embedding of each image.
        output = model(batch_image)
        embeddings.append(output.cpu())
        all_labels.append(batch_label)

embeddings = torch.cat(embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)

X_embedded = TSNE(n_components=2, learning_rate='auto',
                  init='random', perplexity=30).fit_transform(embeddings)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(X_embedded[:,0], X_embedded[:,1], c=all_labels, cmap="tab10", alpha=0.7, s=15)
plt.colorbar(scatter, ticks=range(10), label="Classes")
plt.title("t-SNE Training Embeddings MNIST")
plt.show()
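Note that t-SNE is expensive to fit, and running it on all 60,000 training points can take several minutes on a CPU; if that is too slow, a random subsample gives a very similar qualitative picture. An optional sketch:
# Optional: fit t-SNE on a random 5,000-point subsample for speed.
idx = torch.randperm(embeddings.shape[0])[:5000]
X_small = TSNE(n_components=2, learning_rate='auto',
               init='random', perplexity=30).fit_transform(embeddings[idx])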