Autoencoders
Autoencoders constitute a family of neural network architectures designed to learn compressed representations of data in an unsupervised manner. The fundamental structure of an autoencoder is organized into two main blocks: an encoder, which transforms the original input into a lower-dimensional latent representation, and a decoder, which takes this latent representation and reconstructs from it an approximation of the original input. The training objective consists of minimizing the discrepancy between the reconstructed output and the input, so that the model is forced to capture the most relevant characteristics of the data in the latent space.
This document presents several variants of autoencoders, from basic dense architectures to more advanced models such as variational autoencoders (VAE), Beta-VAE, and VQ-VAE. All implementations are developed on the MNIST dataset and are provided as fully functional code, ready to be executed from start to finish in a Jupyter Notebook environment.
Vanilla Autoencoder with Dense Layers
The vanilla autoencoder uses exclusively dense (fully connected) layers to encode and decode MNIST images. Each image of size \(28 \times 28\) is flattened into a vector of dimension 784 and projected into a lower-dimensional latent space. The encoder applies a sequence of linear transformations and nonlinear activation functions until it reaches the latent space, whereas the decoder performs the inverse process to reconstruct the image.
This configuration introduces the central idea of autoencoders but exhibits clear limitations. Dense layers do not explicitly exploit the spatial structure of the image, which leads to a large number of parameters due to the full connectivity between neurons. In addition, since local relationships between pixels are not modeled explicitly, reconstructions tend to be blurrier and less detailed.
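To make the parameter-count argument concrete, the following standalone sketch (illustrative only, separate from the tutorial's training code) builds the dense encoder described above and counts its weights:

```python
import torch.nn as nn

# Dense encoder from the vanilla autoencoder: 784 -> 256 -> 128 -> 32
encoder = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 32),
)

# Each Linear layer stores in_features * out_features weights plus biases
num_params = sum(p.numel() for p in encoder.parameters())
print(num_params)  # 237984
```

For comparison, a single \(3 \times 3\) convolution with 32 filters over a grayscale image needs only \(1 \cdot 32 \cdot 9 + 32 = 320\) parameters, because the same filter is shared across all spatial positions.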
The following code presents a basic functional implementation on MNIST.
# Third-party packages
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class VanillaAutoencoder(nn.Module):
def __init__(self, input_dim: int = 784, latent_dim: int = 32) -> None:
super().__init__()
# Encoder: Progressively reduces dimensionality
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, latent_dim),
)
# Decoder: Reconstructs from the latent space
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid(), # Output in [0, 1]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Flatten the image
x = x.view(x.size(0), -1)
# Encode
latent = self.encoder(x)
# Decode
reconstructed = self.decoder(latent)
# Return to image shape
return reconstructed.view(-1, 1, 28, 28)
def encode(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1)
return self.encoder(x)
def prepare_mnist_data(batch_size: int = 128):
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
root="./data", train=False, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
def train_autoencoder(
model: nn.Module,
train_loader: DataLoader,
num_epochs: int = 10,
device: str = "cuda",
) -> nn.Module:
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(num_epochs):
total_loss = 0.0
for data, _ in train_loader:
data = data.to(device)
optimizer.zero_grad()
reconstructed = model(data)
loss = criterion(reconstructed, data)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")
return model
def visualize_reconstructions(
model: nn.Module,
test_loader: DataLoader,
num_images: int = 10,
device: str = "cuda",
) -> None:
model.eval()
data, _ = next(iter(test_loader))
data = data[:num_images].to(device)
with torch.no_grad():
reconstructed = model(data)
data = data.cpu()
reconstructed = reconstructed.cpu()
fig, axes = plt.subplots(2, num_images, figsize=(15, 3))
for i in range(num_images):
axes[0, i].imshow(data[i].squeeze(), cmap="gray")
axes[0, i].axis("off")
axes[0, i].set_title("Original")
axes[1, i].imshow(reconstructed[i].squeeze(), cmap="gray")
axes[1, i].axis("off")
axes[1, i].set_title("Reconstructed")
plt.tight_layout()
plt.show()
# Vanilla autoencoder execution
train_loader, test_loader = prepare_mnist_data()
vanilla_ae = VanillaAutoencoder(input_dim=784, latent_dim=32)
device = "cuda" if torch.cuda.is_available() else "cpu"
vanilla_ae = train_autoencoder(vanilla_ae, train_loader, num_epochs=2, device=device)
visualize_reconstructions(vanilla_ae, test_loader, device=device)
Denoising Autoencoder
The denoising autoencoder extends the previous approach by introducing noise into the input during training. In this case, the encoder receives a corrupted version of the image, while the loss function compares the decoder output with the clean original image. This mechanism forces the model to learn robust latent representations that capture the underlying structure of the data, rather than merely approximating the identity function.
Noise is usually introduced as additive Gaussian noise, and values are subsequently clipped to keep them in the range \([0, 1]\). In this way, the model learns to "undo" the corruption, acting as a filter that preserves relevant content and discards spurious details.
The following code illustrates an implementation of this variant on MNIST.
class DenoisingAutoencoder(nn.Module):
def __init__(self, input_dim: int = 784, latent_dim: int = 32) -> None:
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, latent_dim),
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1)
latent = self.encoder(x)
reconstructed = self.decoder(latent)
return reconstructed.view(-1, 1, 28, 28)
def add_noise(images: torch.Tensor, noise_factor: float = 0.3) -> torch.Tensor:
noisy = images + noise_factor * torch.randn_like(images)
noisy = torch.clip(noisy, 0.0, 1.0)
return noisy
def train_denoising_ae(
model: nn.Module,
train_loader: DataLoader,
num_epochs: int = 10,
device: str = "cuda",
noise_factor: float = 0.3,
) -> nn.Module:
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(num_epochs):
total_loss = 0.0
for data, _ in train_loader:
clean_data = data.to(device)
noisy_data = add_noise(clean_data, noise_factor)
optimizer.zero_grad()
reconstructed = model(noisy_data)
loss = criterion(reconstructed, clean_data)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")
return model
def visualize_denoising(
model: nn.Module,
test_loader: DataLoader,
noise_factor: float = 0.3,
num_images: int = 10,
device: str = "cuda",
) -> None:
model.eval()
data, _ = next(iter(test_loader))
data = data[:num_images].to(device)
noisy_data = add_noise(data, noise_factor)
with torch.no_grad():
reconstructed = model(noisy_data)
data = data.cpu()
noisy_data = noisy_data.cpu()
reconstructed = reconstructed.cpu()
fig, axes = plt.subplots(3, num_images, figsize=(15, 5))
    for i in range(num_images):
        axes[0, i].imshow(data[i].squeeze(), cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].imshow(noisy_data[i].squeeze(), cmap="gray")
        axes[1, i].axis("off")
        axes[2, i].imshow(reconstructed[i].squeeze(), cmap="gray")
        axes[2, i].axis("off")
    # Row labels on the first column (set_ylabel is not drawn once the
    # axis has been turned off, so use set_title instead)
    for row, label in enumerate(["Original", "Noisy", "Denoised"]):
        axes[row, 0].set_title(label)
plt.tight_layout()
plt.show()
# Denoising autoencoder execution
denoising_ae = DenoisingAutoencoder(input_dim=784, latent_dim=32)
denoising_ae = train_denoising_ae(
denoising_ae, train_loader, num_epochs=2, device=device
)
visualize_denoising(denoising_ae, test_loader, device=device)
Convolutional Autoencoder
Convolutional autoencoders are better suited to image data because they explicitly exploit spatial structure. The encoder applies convolutions with shared weights and local filters; spatial dimensionality is reduced through stride and the stacking of layers. The decoder uses transposed convolutions to perform upsampling and reconstruct the original resolution.
In this context, convolutions provide several advantages. They significantly reduce the number of parameters compared with dense layers, due to weight sharing across different spatial positions. They also capture local patterns and hierarchical structures (edges, digit parts, whole digits), which leads to sharper reconstructions that are more consistent with image content.
The following implementation illustrates a convolutional autoencoder with a linear bottleneck.
class ConvAutoencoder(nn.Module):
def __init__(self, latent_dim: int = 128) -> None:
super().__init__()
# Convolutional encoder
self.encoder = nn.Sequential(
# 28x28 -> 14x14
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
# 14x14 -> 7x7
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
# 7x7 -> 4x4 (slight additional reduction)
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
)
# Linear bottleneck
self.flatten = nn.Flatten()
self.fc_encode = nn.Linear(128 * 4 * 4, latent_dim)
self.fc_decode = nn.Linear(latent_dim, 128 * 4 * 4)
self.unflatten = nn.Unflatten(1, (128, 4, 4))
# Decoder with transposed convolutions
self.decoder = nn.Sequential(
# 4x4 -> 7x7
nn.ConvTranspose2d(
128, 64, kernel_size=3, stride=2, padding=1, output_padding=0
),
nn.ReLU(),
nn.BatchNorm2d(64),
# 7x7 -> 14x14
nn.ConvTranspose2d(
64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
),
nn.ReLU(),
nn.BatchNorm2d(32),
# 14x14 -> 28x28
nn.ConvTranspose2d(
32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encode
x = self.encoder(x)
x = self.flatten(x)
latent = self.fc_encode(x)
# Decode
x = self.fc_decode(latent)
x = self.unflatten(x)
reconstructed = self.decoder(x)
return reconstructed
def encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.flatten(x)
return self.fc_encode(x)
# Training the convolutional autoencoder
conv_ae = ConvAutoencoder(latent_dim=128)
conv_ae = train_autoencoder(conv_ae, train_loader, num_epochs=2, device=device)
visualize_reconstructions(conv_ae, test_loader, device=device)
Transposed convolutions can introduce characteristic artifacts known as checkerboard artifacts, which arise when the combination of kernel size and stride produces uneven overlaps during the upsampling operation.
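The uneven overlap can be observed directly. In the following standalone illustration, a `ConvTranspose2d` with all-ones weights is applied to a constant image; if contributions were uniform, the output would also be constant, but with kernel size 3 and stride 2 neighboring output pixels receive different numbers of contributions:

```python
import torch
import torch.nn as nn

# Transposed convolution with kernel 3 and stride 2: overlaps are uneven
up = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, bias=False)
nn.init.ones_(up.weight)  # every kernel contribution counts as exactly 1

x = torch.ones(1, 1, 4, 4)  # constant input image
with torch.no_grad():
    y = up(x)

# Each output pixel equals the number of kernel positions covering it;
# the values alternate between 1, 2, and 4: a checkerboard pattern.
print(sorted(y.unique().tolist()))  # [1.0, 2.0, 4.0]
```

The same uneven-overlap pattern appears, attenuated, in trained decoders, which is the origin of the visible artifacts.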
Autoencoder with Interpolation-Based Upsampling
To mitigate checkerboard artifacts, it is common to replace transposed convolutions with an upsampling strategy based on interpolation followed by standard convolutions. In this configuration, the spatial resolution is first increased by interpolation (bilinear, bicubic, etc.), and then a convolution is applied to refine the result and learn filters over the rescaled image.
This procedure tends to produce smoother and visually more coherent reconstructions, significantly reducing undesired patterns at the cost of some additional computation.
The following model preserves the same convolutional encoder as the previous autoencoder but replaces the ConvTranspose2d-based decoder with a decoder that combines Upsample and Conv2d.
class UpsamplingAutoencoder(nn.Module):
def __init__(self, latent_dim: int = 128) -> None:
super().__init__()
# Encoder identical to the convolutional autoencoder
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
)
self.flatten = nn.Flatten()
self.fc_encode = nn.Linear(128 * 4 * 4, latent_dim)
self.fc_decode = nn.Linear(latent_dim, 128 * 4 * 4)
self.unflatten = nn.Unflatten(1, (128, 4, 4))
# Decoder with upsampling + convolution
self.decoder = nn.Sequential(
# 4x4 -> 7x7
nn.Upsample(size=(7, 7), mode="bilinear", align_corners=False),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
# 7x7 -> 14x14
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
# 14x14 -> 28x28
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(32, 1, kernel_size=3, padding=1),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.flatten(x)
latent = self.fc_encode(x)
x = self.fc_decode(latent)
x = self.unflatten(x)
reconstructed = self.decoder(x)
return reconstructed
# Training and visualization
upsampling_ae = UpsamplingAutoencoder(latent_dim=128)
upsampling_ae = train_autoencoder(
upsampling_ae, train_loader, num_epochs=2, device=device
)
visualize_reconstructions(upsampling_ae, test_loader, device=device)
The use of bilinear or bicubic interpolation followed by standard convolutions generally produces visually more pleasant reconstructions and significantly reduces checkerboard artifacts, while preserving the model's ability to capture high-level patterns.
Variational Autoencoder (VAE)
The variational autoencoder (VAE) introduces an important conceptual change with respect to deterministic autoencoders. Instead of learning a direct mapping from the input to a fixed latent vector, the encoder learns the parameters of a probability distribution over the latent space. It is usually assumed that each latent dimension follows an independent Gaussian distribution, so the encoder produces a mean \(\mu\) and a logarithm of the variance \(\log \sigma^2\) for each dimension.
During training, a sample \(z\) is drawn from the latent space using the reparameterization trick:
\[
z = \mu + \sigma \odot \varepsilon,
\]
where \(\varepsilon \sim \mathcal{N}(0, I)\) and \(\sigma = \exp\left(\tfrac{1}{2} \log \sigma^2\right)\).
This formulation allows gradients to be backpropagated through the sampling operation.
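As a quick standalone check (illustrative only), the reparameterized sample has the intended statistics while remaining differentiable with respect to \(\mu\) and \(\log \sigma^2\):

```python
import torch

mu = torch.tensor([1.0, -2.0], requires_grad=True)
logvar = torch.tensor([0.0, 0.0], requires_grad=True)  # sigma = 1

# Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I)
std = torch.exp(0.5 * logvar)
eps = torch.randn(100_000, 2)
z = mu + eps * std

# The empirical mean approaches mu, and gradients flow through mu / logvar
print(z.mean(dim=0))
z.sum().backward()
print(mu.grad)  # tensor([100000., 100000.])
```

Sampling `eps` externally is what makes the stochastic node differentiable: the randomness is an input, not an operation in the graph.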
The VAE loss function includes two terms. The first is the reconstruction loss, which measures the discrepancy between the original and reconstructed images (for example, using binary cross-entropy). The second is a regularization term based on the Kullback-Leibler (KL) divergence between the learned latent distribution and a standard normal distribution \(\mathcal{N}(0, I)\):
\[
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right).
\]
This term forces the latent space to adopt a well-structured distribution, facilitating sampling and the generation of new examples.
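The closed-form KL expression (the same one used in `vae_loss` below) can be sanity-checked against `torch.distributions` in a standalone snippet:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 20)
logvar = torch.randn(4, 20)

# Closed-form KL divergence to N(0, I), summed over batch and dimensions
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference value from torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(torch.allclose(kl_closed, kl_ref, atol=1e-3))  # True
```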
The following code presents a convolutional VAE implementation for MNIST.
# Third-party packages
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class VAE(nn.Module):
def __init__(self, latent_dim: int = 20) -> None:
super().__init__()
# Convolutional encoder
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), # 28x28 -> 14x14
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7
nn.ReLU(),
nn.Flatten(),
)
self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
# Decoder
self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
self.decoder = nn.Sequential(
nn.Unflatten(1, (64, 7, 7)),
nn.ConvTranspose2d(
64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
), # 7x7 -> 14x14
nn.ReLU(),
nn.ConvTranspose2d(
32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
), # 14x14 -> 28x28
nn.Sigmoid(),
)
def encode(self, x: torch.Tensor):
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z: torch.Tensor) -> torch.Tensor:
x = self.fc_decode(z)
return self.decoder(x)
def forward(self, x: torch.Tensor):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
reconstructed = self.decode(z)
return reconstructed, mu, logvar
def vae_loss(
reconstructed: torch.Tensor,
original: torch.Tensor,
mu: torch.Tensor,
logvar: torch.Tensor,
) -> torch.Tensor:
recon_loss = nn.functional.binary_cross_entropy(
reconstructed, original, reduction="sum"
)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss
def train_vae(
model: nn.Module,
train_loader: DataLoader,
num_epochs: int = 10,
device: str = "cuda",
) -> nn.Module:
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
for data, _ in train_loader:
data = data.to(device)
optimizer.zero_grad()
reconstructed, mu, logvar = model(data)
loss = vae_loss(reconstructed, data, mu, logvar)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(
f"Epoch [{epoch+1}/{num_epochs}], "
f"Loss: {total_loss / len(train_loader.dataset):.4f}"
)
return model
def visualize_latent_space_tsne(
model: VAE, data_loader: DataLoader, device: str = "cuda", n_samples: int = 5000
) -> None:
"""Visualize the latent space using t-SNE."""
model.eval()
latent_vectors = []
labels = []
with torch.no_grad():
for data, label in data_loader:
data = data.to(device)
mu, _ = model.encode(data)
latent_vectors.append(mu.cpu().numpy())
labels.append(label.numpy())
if len(latent_vectors) * data.size(0) >= n_samples:
break
latent_vectors = np.concatenate(latent_vectors, axis=0)[:n_samples]
labels = np.concatenate(labels, axis=0)[:n_samples]
print("Applying t-SNE...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
latent_2d = tsne.fit_transform(latent_vectors)
plt.figure(figsize=(12, 10))
scatter = plt.scatter(
latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap="tab10", alpha=0.6, s=5
)
plt.colorbar(scatter, label="Digit")
plt.title("t-SNE Visualization of the VAE Latent Space")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.tight_layout()
plt.show()
def generate_samples(
model: VAE, num_samples: int = 16, latent_dim: int = 20, device: str = "cuda"
) -> None:
model.eval()
with torch.no_grad():
z = torch.randn(num_samples, latent_dim).to(device)
samples = model.decode(z).cpu()
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
ax.imshow(samples[i].squeeze(), cmap="gray")
ax.axis("off")
plt.tight_layout()
plt.show()
def prepare_mnist_data(batch_size: int = 128):
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
root="./data", train=False, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
# Prepare data
train_loader, test_loader = prepare_mnist_data()
# Train VAE
vae = VAE(latent_dim=20)
device = "cuda" if torch.cuda.is_available() else "cpu"
vae = train_vae(vae, train_loader, num_epochs=2, device=device)
# Visualize latent space with t-SNE
visualize_latent_space_tsne(vae, test_loader, device=device)
# Generate synthetic samples
generate_samples(vae, latent_dim=20, device=device)
# Visualize reconstructions
with torch.no_grad():
data, _ = next(iter(test_loader))
data = data[:10].to(device)
reconstructed, _, _ = vae(data)
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for i in range(10):
axes[0, i].imshow(data[i].cpu().squeeze(), cmap="gray")
axes[0, i].axis("off")
axes[1, i].imshow(reconstructed[i].cpu().squeeze(), cmap="gray")
axes[1, i].axis("off")
# set_ylabel is not drawn when the axis is off; use titles instead
axes[0, 0].set_title("Original", size=12)
axes[1, 0].set_title("Reconstructed", size=12)
plt.tight_layout()
plt.show()
VAEs are particularly useful for generating synthetic data by direct sampling in the latent space and for anomaly detection by analyzing out-of-distribution examples. However, they can suffer from the posterior collapse phenomenon, in which the decoder largely ignores latent information and learns to reconstruct from local patterns alone, reducing the quality and informativeness of latent representations.
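One common diagnostic for posterior collapse (a hedged sketch, not part of the tutorial's pipeline) is to monitor the KL divergence per latent dimension: dimensions whose average KL stays near zero carry almost no information. The tensors below are stand-ins for one batch of encoder outputs, and the `0.01` threshold is heuristic:

```python
import torch

# Stand-in encoder outputs for one batch: shape (batch_size, latent_dim)
torch.manual_seed(0)
mu = torch.randn(128, 20) * 0.01  # near-zero means ...
logvar = torch.zeros(128, 20)     # ... and unit variances: collapsed-looking

# KL divergence per dimension, averaged over the batch
kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean(dim=0)

# Flag dimensions that encode almost nothing (threshold is heuristic)
collapsed = (kl_per_dim < 0.01).sum().item()
print(f"{collapsed}/20 dimensions look collapsed")
```

In a healthy VAE, at least some dimensions show clearly nonzero per-dimension KL; a sudden drop of all of them toward zero during training is the classic collapse signature.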
Beta-VAE
The Beta-VAE introduces a hyperparameter \(\beta\) in the VAE loss function to weight the KL divergence term:
\[
\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \, D_{\mathrm{KL}}\bigl(q(z \mid x) \,\|\, \mathcal{N}(0, I)\bigr).
\]
When \(\beta > 1\), the model is forced to align the latent distribution more strongly with the standard normal distribution, which tends to produce more disentangled representations. In a disentangled latent space, each dimension preferentially captures an independent factor of variation in the data (for example, stroke thickness, slant, or size), improving interpretability and control over generated samples.
Excessively high values of \(\beta\) can degrade reconstruction quality by penalizing latent code complexity too strongly.
The following code shows how to adapt the loss and training procedure for a Beta-VAE using the VAE architecture defined above.
def beta_vae_loss(
reconstructed: torch.Tensor,
original: torch.Tensor,
mu: torch.Tensor,
logvar: torch.Tensor,
beta: float = 4.0,
) -> torch.Tensor:
recon_loss = nn.functional.binary_cross_entropy(
reconstructed, original, reduction="sum"
)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + beta * kl_loss
def train_beta_vae(
model: VAE,
train_loader: DataLoader,
num_epochs: int = 10,
beta: float = 4.0,
device: str = "cuda",
) -> VAE:
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(num_epochs):
total_loss = 0.0
for data, _ in train_loader:
data = data.to(device)
optimizer.zero_grad()
reconstructed, mu, logvar = model(data)
loss = beta_vae_loss(reconstructed, data, mu, logvar, beta)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader.dataset)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
return model
To explore the effect of controlled variations along individual latent dimensions, the latent traversal technique is used. It consists of systematically modifying a single latent coordinate while keeping the remaining ones fixed.
def visualize_latent_traversal(
model: VAE,
test_loader: DataLoader,
latent_dim: int = 20,
dim_to_vary: int = 0,
device: str = "cuda",
) -> None:
model.eval()
data, _ = next(iter(test_loader))
data = data[0:1].to(device)
with torch.no_grad():
mu, _ = model.encode(data)
values = torch.linspace(-3, 3, 10)
samples = []
for val in values:
z = mu.clone()
z[0, dim_to_vary] = val
reconstructed = model.decode(z)
samples.append(reconstructed)
samples = torch.cat(samples, dim=0)
samples = samples.cpu()
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axes.flat):
ax.imshow(samples[i].squeeze(), cmap="gray")
ax.axis("off")
ax.set_title(f"{values[i]:.1f}")
plt.tight_layout()
plt.show()
# Training the Beta-VAE
beta_vae = VAE(latent_dim=20)
beta_vae = train_beta_vae(
beta_vae, train_loader, num_epochs=2, beta=4.0, device=device
)
# Visualize variation of some latent dimensions
for dim in range(5):
visualize_latent_traversal(beta_vae, test_loader, dim_to_vary=dim, device=device)
Latent traversal enables inspection of the influence of each latent dimension on generated samples, facilitating the interpretation of disentangled representations and the design of controlled manipulations over specific attributes.
VQ-VAE (Vector Quantized VAE)
VQ-VAE introduces a fundamental modification in the treatment of the latent space. Instead of continuous codes, it uses a discrete representation based on a learned codebook of embeddings. The encoder projects the input into a continuous latent tensor of dimension \(C\); each latent vector is then quantized by selecting the closest embedding from the codebook, that is, by assigning a discrete index. The decoder receives the quantized embeddings and reconstructs the input.
This discretization offers several advantages. It avoids the posterior collapse problem typical of some VAEs, as quantization forces the model to actively use the latent space. Moreover, the discrete representation is particularly well suited to be modeled later using autoregressive models (for example, transformers), which has been crucial in generative architectures such as DALL·E. In this context, latent indices act as tokens on which language-modeling techniques can be applied.
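The nearest-neighbor selection itself is non-differentiable; the implementation below relies on the straight-through estimator, which copies gradients from the quantized output back to the encoder. The trick can be illustrated in isolation, using `torch.round` as a stand-in for the codebook lookup:

```python
import torch

x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)

# Non-differentiable quantization (stand-in for the codebook lookup)
q = torch.round(x)

# Straight-through estimator: the forward pass uses q, while the
# backward pass treats the operation as the identity
q_st = x + (q - x).detach()

loss = (q_st ** 2).sum()
loss.backward()

print(q_st.detach())  # tensor([0., 1., 1.])
print(x.grad)         # tensor([0., 2., 2.]), i.e. the gradient 2*q
```

Without the `detach` trick, `x.grad` would be `None` for any path through `q`, and the encoder would receive no learning signal.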
The following code presents a simple VQ-VAE implementation for MNIST, including the vector quantization module.
"""VQ-VAE (Vector Quantized Variational Autoencoder) Implementation"""
# Third-party packages
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
class VectorQuantizer(nn.Module):
"""
Vector Quantizer layer for VQ-VAE.
Converts continuous latent vectors into discrete codes from the codebook.
"""
def __init__(
self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
# Codebook of embeddings
self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs: Tensor of shape (B, C, H, W)
Returns:
quantized: Quantized tensor (B, C, H, W)
loss: Quantization loss (codebook + commitment)
encoding_indices: Indices of selected codebook vectors
"""
# Reorder to (B, H, W, C)
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape
# Flatten to (B*H*W, C)
flat_input = inputs.view(-1, self.embedding_dim)
# L2 distances to each codebook embedding
distances = (
torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self.embeddings.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self.embeddings.weight.t())
)
# Index of nearest embedding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
# One-hot encoding
encodings = torch.zeros(
encoding_indices.shape[0], self.num_embeddings, device=inputs.device
)
encodings.scatter_(1, encoding_indices, 1)
# Quantization via codebook
quantized = torch.matmul(encodings, self.embeddings.weight).view(input_shape)
# VQ losses
e_latent_loss = nn.functional.mse_loss(quantized.detach(), inputs)
q_latent_loss = nn.functional.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self.commitment_cost * e_latent_loss
# Straight-through estimator
quantized = inputs + (quantized - inputs).detach()
# Back to (B, C, H, W)
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized, loss, encoding_indices
class VQVAE(nn.Module):
"""
VQ-VAE model with encoder, vector quantizer, and decoder.
"""
def __init__(self, num_embeddings: int = 512, embedding_dim: int = 64) -> None:
super().__init__()
# Encoder: (1, 28, 28) -> (embedding_dim, 7, 7)
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), # 28x28 -> 14x14
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # 14x14 -> 7x7
nn.ReLU(),
nn.Conv2d(64, embedding_dim, kernel_size=1), # 7x7, C=embedding_dim
)
# Vector Quantizer
self.vq = VectorQuantizer(num_embeddings, embedding_dim)
# Decoder: (embedding_dim, 7, 7) -> (1, 28, 28)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(
embedding_dim, 64, kernel_size=4, stride=2, padding=1
), # 7x7 -> 14x14
nn.ReLU(),
nn.ConvTranspose2d(
64, 32, kernel_size=4, stride=2, padding=1
), # 14x14 -> 28x28
nn.ReLU(),
nn.ConvTranspose2d(32, 1, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor):
"""
Args:
x: Input tensor (B, 1, 28, 28)
Returns:
reconstructed: Reconstruction (B, 1, 28, 28)
vq_loss: Vector quantization loss
"""
z = self.encoder(x)
quantized, vq_loss, _ = self.vq(z)
reconstructed = self.decoder(quantized)
return reconstructed, vq_loss
def encode(self, x: torch.Tensor):
"""Encode and quantize the input."""
z = self.encoder(x)
quantized, _, indices = self.vq(z)
return quantized, indices
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Decode a quantized latent tensor."""
return self.decoder(z)
def train_vqvae(
model: VQVAE,
train_loader: DataLoader,
num_epochs: int = 10,
lr: float = 1e-3,
device: str = "cuda",
) -> VQVAE:
"""
Train the VQ-VAE model.
Args:
model: VQVAE model.
train_loader: Training DataLoader.
num_epochs: Number of epochs.
lr: Learning rate.
device: "cuda" or "cpu".
Returns:
Trained model.
"""
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
model.train()
for epoch in range(num_epochs):
total_recon_loss = 0.0
total_vq_loss = 0.0
for data, _ in train_loader:
data = data.to(device)
optimizer.zero_grad()
reconstructed, vq_loss = model(data)
recon_loss = nn.functional.mse_loss(reconstructed, data)
loss = recon_loss + vq_loss
loss.backward()
optimizer.step()
total_recon_loss += recon_loss.item()
total_vq_loss += vq_loss.item()
avg_recon = total_recon_loss / len(train_loader)
avg_vq = total_vq_loss / len(train_loader)
print(
f"Epoch [{epoch+1}/{num_epochs}] | "
f"Recon Loss: {avg_recon:.6f} | "
f"VQ Loss: {avg_vq:.6f}"
)
return model
def visualize_vqvae_reconstructions(
model: VQVAE, test_loader: DataLoader, device: str = "cuda", num_images: int = 8
) -> None:
"""
Visualize original and VQ-VAE reconstructed images.
"""
model.eval()
data, _ = next(iter(test_loader))
data = data[:num_images].to(device)
with torch.no_grad():
reconstructed, _ = model(data)
data = data.cpu()
reconstructed = reconstructed.cpu()
fig, axes = plt.subplots(2, num_images, figsize=(12, 3))
for i in range(num_images):
axes[0, i].imshow(data[i].squeeze(), cmap="gray")
axes[0, i].axis("off")
axes[1, i].imshow(reconstructed[i].squeeze(), cmap="gray")
axes[1, i].axis("off")
    # set_ylabel is not drawn when the axis is off; use titles instead
    axes[0, 0].set_title("Original", size=12)
    axes[1, 0].set_title("Reconstructed", size=12)
plt.tight_layout()
plt.show()
# Main VQ-VAE execution
NUM_EMBEDDINGS = 512
EMBEDDING_DIM = 64
NUM_EPOCHS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
vqvae = VQVAE(num_embeddings=NUM_EMBEDDINGS, embedding_dim=EMBEDDING_DIM)
vqvae = train_vqvae(vqvae, train_loader, num_epochs=NUM_EPOCHS, device=DEVICE)
visualize_vqvae_reconstructions(vqvae, test_loader, device=DEVICE)
VQ-VAE provides a discrete latent space that is particularly suitable for integration into multimodal systems and complex generative models, in which tokenization of data is essential. Vector quantization offers a robust foundation for applying advanced sequential modeling techniques to image representations and facilitates integration with language architectures that operate on discrete sequences.
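To close the loop on the "latent indices as tokens" idea, the following standalone sketch (a minimal toy quantizer, not the `VQVAE` class above) maps a batch of latent feature maps to flat sequences of discrete token ids, the form an autoregressive prior would consume:

```python
import torch

torch.manual_seed(0)
num_embeddings, embedding_dim = 512, 64

# Toy codebook and a batch of "encoder outputs" of shape (B, C, H, W)
codebook = torch.randn(num_embeddings, embedding_dim)
z = torch.randn(2, embedding_dim, 7, 7)

# Move channels last and flatten to (B*H*W, C)
flat = z.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

# Nearest codebook entry for each latent vector
indices = torch.cdist(flat, codebook).argmin(dim=1)

# Reshape to one token sequence per image: 7x7 grid -> 49 tokens
tokens = indices.view(2, 7 * 7)
print(tokens.shape)  # torch.Size([2, 49])
```

Each image thus becomes a sequence of 49 integers in `[0, 512)`, directly analogous to the token sequences consumed by language models.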