Attention Mechanisms in Convolutional Neural Networks

Attention mechanisms in convolutional neural networks enable the model to adaptively focus on the most relevant features of the input signal, either at the channel level or at the spatial level. These modules learn to recalibrate intermediate activations by assigning them learned importance weights, which increases the representational capacity of the model without substantially increasing the number of parameters or the computational cost.

In modern architectures, attention is integrated in a modular way into existing convolutional blocks, such as the residual blocks of ResNet. The following sections describe and implement two of the most influential attention mechanisms in convolutional networks: the Squeeze-and-Excitation (SE) block and the Convolutional Block Attention Module (CBAM).

Squeeze-and-Excitation (SE) Block

The Squeeze-and-Excitation block, introduced by Hu et al. in the paper Squeeze-and-Excitation Networks, adds a channel-wise attention mechanism. The central idea is to explicitly model the interdependencies between feature channels so that the network learns to emphasize the channels that are most informative for the task while suppressing less relevant or redundant ones.

The SE mechanism consists of two conceptual stages, commonly referred to as squeeze and excitation. In the squeeze phase, the spatial dimensions of each feature map are collapsed by global average pooling, so that each channel is compressed into a single scalar that summarizes its activation across the entire image. In the excitation phase, these aggregated values are fed into a small fully connected network that learns a channel-wise attention function. Its output, passed through a sigmoid, is a vector of weights in the interval \((0, 1)\) that multiplies the original channels and thereby recalibrates their relative importance.

Let \(X \in \mathbb{R}^{B \times C \times H \times W}\) be a feature tensor with batch size \(B\), number of channels \(C\), and spatial dimensions \(H \times W\). For each sample in the batch (the batch index is omitted for brevity), the squeeze operation computes, for each channel \(c\),

\[z_c = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} X_c(i, j)\]
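As a quick sanity check of the squeeze equation, the sketch below computes \(z_c\) three ways on a small random tensor: a literal double loop over \(i, j\) divided by \(HW\), a mean over the two spatial axes, and `nn.AdaptiveAvgPool2d(1)`, the layer used in the implementation that follows. The tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 8, 5, 7
x = torch.randn(B, C, H, W)

# Literal translation of the squeeze equation: double sum over i, j, divided by H*W
z_loop = torch.zeros(B, C)
for i in range(H):
    for j in range(W):
        z_loop += x[:, :, i, j]
z_loop /= H * W

# The same quantity as a mean over the two spatial axes
z_mean = x.mean(dim=(2, 3))

# The same quantity via global average pooling, flattened from (B, C, 1, 1) to (B, C)
z_pool = nn.AdaptiveAvgPool2d(1)(x).flatten(1)

print(z_loop.shape)  # torch.Size([2, 8])
print(torch.allclose(z_loop, z_mean, atol=1e-5), torch.allclose(z_loop, z_pool, atol=1e-5))
```

Global average pooling is therefore just a convenient, shape-agnostic way to evaluate the sum in the equation for every channel at once.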

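The compressed vector \(z \in \mathbb{R}^{C}\) is processed by a two-layer fully connected network whose hidden layer reduces the dimensionality by a factor \(r\), the reduction ratio. With bias-free layers, as in the implementation below, the excitation computes

\[s = \sigma\left(W_2\, \delta(W_1 z)\right), \qquad W_1 \in \mathbb{R}^{(C/r) \times C}, \quad W_2 \in \mathbb{R}^{C \times (C/r)}\]

where \(\delta\) denotes the ReLU and \(\sigma\) the sigmoid, so that \(s \in (0, 1)^{C}\). The recalibration is implemented as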

\[\tilde{X}_c(i, j) = s_c \cdot X_c(i, j)\]
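The squeeze, excitation and recalibration steps can be written out directly with tensor operations. The sketch below assumes a bias-free two-layer network, \(C \to C/r\) with a ReLU followed by \(C/r \to C\) with a sigmoid, and uses randomly initialized matrices `W1` and `W2` as hypothetical stand-ins for learned weights. The key step is reshaping the weight vector to \((B, C, 1, 1)\) so that broadcasting multiplies every spatial position of channel \(c\) by the same scalar \(s_c\).

```python
import torch

torch.manual_seed(0)
B, C, H, W, r = 2, 16, 4, 4, 4
x = torch.randn(B, C, H, W)

# Hypothetical stand-ins for learned, bias-free weights: C -> C/r -> C
W1 = torch.randn(C // r, C) * 0.1   # reduction layer
W2 = torch.randn(C, C // r) * 0.1   # expansion layer

# Squeeze: global average pooling over H and W, z has shape (B, C)
z = x.mean(dim=(2, 3))

# Excitation: s = sigmoid(W2 relu(W1 z)), computed for the whole batch at once
s = torch.sigmoid(torch.relu(z @ W1.T) @ W2.T)   # shape (B, C), values in (0, 1)

# Recalibration: reshape s to (B, C, 1, 1) so it broadcasts over H and W
x_tilde = x * s.view(B, C, 1, 1)

print(x_tilde.shape)  # torch.Size([2, 16, 4, 4])
# Every spatial position of channel c is scaled by the same s_c
print(torch.allclose(x_tilde[0, 3], s[0, 3] * x[0, 3]))
```

This is the same computation the `SqueezeExcitation` module performs, with `nn.Linear(..., bias=False)` layers holding the weights instead of explicit matrices.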

The following code shows an implementation of the SE block in PyTorch and its integration into a basic residual block, where the SE block rescales the output of the residual branch before it is added to the shortcut. The code is self-contained and can be run as is.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SqueezeExcitation(nn.Module):
    def __init__(self, in_channels: int, reduction_ratio: int = 16) -> None:
        super().__init__()
        reduced_channels = max(in_channels // reduction_ratio, 1)
        # Squeeze: Global average pooling per channel
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Excitation: Two fully connected (implemented as Linear) layers
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, reduced_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, _, _ = x.size()
        # Squeeze: Global average pooling per channel
        squeezed = self.squeeze(x).view(batch_size, channels)
        # Excitation: Channel-wise weights in (0, 1)
        excited = self.excitation(squeezed).view(batch_size, channels, 1, 1)
        # Channel-wise recalibration
        return x * excited

class SEResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        reduction_ratio: int = 16,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SqueezeExcitation(out_channels, reduction_ratio)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out += identity
        out = F.relu(out)
        return out

To verify the correct construction and behavior of the SE block, a small functional test can be defined. This test checks that input and output shapes match and reports the number of parameters of the SE module.

def test_se_block() -> None:
    x = torch.randn(2, 64, 32, 32)
    se_block = SqueezeExcitation(in_channels=64, reduction_ratio=16)
    output = se_block(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"SE parameters: {sum(p.numel() for p in se_block.parameters())}")
    assert x.shape == output.shape, "Shape mismatch"
    print("SE Block test passed")

test_se_block()

The SE block introduces only a small number of additional parameters, controlled by the hyperparameter reduction_ratio. This parameter determines the bottleneck size in the excitation network: larger values reduce the capacity of the module but also decrease its computational cost. In practice, reduction_ratio = 16, the value used by default in the original SE paper, usually provides a good balance between modeling capacity and efficiency.
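The overhead can be quantified with a short calculation. With bias-free layers, as in `SqueezeExcitation`, each of the two excitation layers holds \(C \cdot \lfloor C/r \rfloor\) weights, so the block adds about \(2C^2/r\) parameters, while the squeeze step has none. The sketch below evaluates this count for a few channel widths, cross-checks it against actual `nn.Linear` layers, and compares it with one bias-free \(3 \times 3\) convolution of the same width. The helper name `se_param_count` is hypothetical.

```python
import torch.nn as nn


def se_param_count(channels: int, reduction_ratio: int = 16) -> int:
    """Weights in a bias-free SE excitation network: C -> C/r -> C."""
    reduced = max(channels // reduction_ratio, 1)
    return 2 * channels * reduced


for c in (64, 256, 512):
    # Cross-check the formula against real bias-free nn.Linear layers
    reduced = max(c // 16, 1)
    layers = [nn.Linear(c, reduced, bias=False), nn.Linear(reduced, c, bias=False)]
    actual = sum(p.numel() for layer in layers for p in layer.parameters())
    conv3x3 = c * c * 3 * 3  # one bias-free 3x3 conv with c input and c output channels
    print(f"C={c}: SE={se_param_count(c)} (checked: {actual}), 3x3 conv={conv3x3}")
```

This prints 512 versus 36864 for \(C = 64\), 8192 versus 589824 for \(C = 256\), and 32768 versus 2359296 for \(C = 512\). The ratio is \(2/(9r)\), about 1.4% of a single \(3 \times 3\) convolution for \(r = 16\), independent of \(C\).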

Convolutional Block Attention Module (CBAM)

The Convolutional Block Attention Module (CBAM) extends the SE idea by sequentially incorporating attention both in the channel domain and in the spatial domain. First, it applies a channel attention module conceptually similar to SE, but combining information from global average pooling and global max pooling. Subsequently, it applies a spatial attention module that analyzes the distribution of activations across channels to determine which regions of the image are most relevant.

The channel attention module in CBAM is built from two parallel paths. One path takes the output of global average pooling and the other the output of global max pooling, both computed over the spatial dimensions of each channel. Both descriptors are processed by the same small two-layer network (shared weights, with a bottleneck of size \(C/r\)), implemented here with \(1 \times 1\) convolutions; on a \(1 \times 1\) input, a \(1 \times 1\) convolution is equivalent to a fully connected layer. The two outputs are added element-wise and passed through a sigmoid to obtain a channel attention map that modulates the contribution of each channel:

\[M_c(X) = \sigma\left(\mathrm{MLP}(\mathrm{AvgPool}(X)) + \mathrm{MLP}(\mathrm{MaxPool}(X))\right)\]
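To make the "shared fully connected projection" concrete, the sketch below copies the weights of a bias-free \(1 \times 1\) convolutional MLP into equivalent `nn.Linear` layers and checks that both produce the same channel attention map. It also shows the weight sharing: the same `mlp` object processes the average-pooled and the max-pooled descriptors. The sizes and random weights are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, reduced = 2, 32, 2
x = torch.randn(B, C, 8, 8)

# Shared MLP built from 1x1 convolutions, as in the CBAM channel attention module
mlp = nn.Sequential(
    nn.Conv2d(C, reduced, kernel_size=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(reduced, C, kernel_size=1, bias=False),
)

# Equivalent fully connected layers: a 1x1 conv weight (out, in, 1, 1) is a Linear weight (out, in)
fc1 = nn.Linear(C, reduced, bias=False)
fc2 = nn.Linear(reduced, C, bias=False)
with torch.no_grad():
    fc1.weight.copy_(mlp[0].weight.view(reduced, C))
    fc2.weight.copy_(mlp[2].weight.view(C, reduced))

avg = x.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
mx = x.amax(dim=(2, 3), keepdim=True)   # (B, C, 1, 1)

# The same MLP weights process both descriptors; the results are summed before the sigmoid
attn_conv = torch.sigmoid(mlp(avg) + mlp(mx))  # (B, C, 1, 1)
attn_fc = torch.sigmoid(
    fc2(torch.relu(fc1(avg.flatten(1)))) + fc2(torch.relu(fc1(mx.flatten(1))))
)  # (B, C)

print(attn_conv.shape)  # torch.Size([2, 32, 1, 1])
print(torch.allclose(attn_conv.flatten(1), attn_fc, atol=1e-5))
```

Using \(1 \times 1\) convolutions keeps the descriptors in \((B, C, 1, 1)\) layout, so the attention map can multiply the input directly without any reshaping.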

The spatial attention module is applied to the feature maps already recalibrated by the channel module, denoted \(X'\). Two single-channel spatial maps are computed by aggregating over the channel dimension with mean and maximum operations. These two maps are concatenated along the channel axis and processed by a \(k \times k\) convolution \(f^{k \times k}\), typically with \(k = 7\), followed by a sigmoid activation:

\[M_s(X') = \sigma\left(f^{k \times k}\left([\mathrm{AvgPool}_c(X');\, \mathrm{MaxPool}_c(X')]\right)\right)\]

The resulting map, of shape \(1 \times H \times W\), is multiplied element-wise with \(X'\) (broadcast over channels), modulating the importance of each spatial position \((i, j)\) in the image.
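The spatial branch can be traced step by step on a small tensor. The sketch below builds the two single-channel maps with a mean and a max over the channel axis, concatenates them into a 2-channel descriptor, and applies a \(7 \times 7\) convolution with padding \((k - 1)/2 = 3\), which keeps the output at \(H \times W\) for odd \(k\). The randomly initialized convolution is a stand-in for learned weights, and the input tensor is a stand-in for the channel-refined features \(X'\).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, k = 2, 16, 10, 12, 7
x = torch.randn(B, C, H, W)  # stand-in for the channel-refined features X'

# Aggregate over the channel axis: two maps of shape (B, 1, H, W)
avg_map = x.mean(dim=1, keepdim=True)
max_map = x.amax(dim=1, keepdim=True)
descriptor = torch.cat([avg_map, max_map], dim=1)  # (B, 2, H, W)

# k x k convolution from 2 channels to 1; padding (k - 1) // 2 preserves H and W for odd k
conv = nn.Conv2d(2, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
attention = torch.sigmoid(conv(descriptor))  # (B, 1, H, W), values in (0, 1)

# Broadcasting over the channel axis: all channels at position (i, j) share one weight
out = x * attention

print(descriptor.shape, attention.shape, out.shape)
# torch.Size([2, 2, 10, 12]) torch.Size([2, 1, 10, 12]) torch.Size([2, 16, 10, 12])
```

With an even kernel size, this padding rule would shrink the output by one pixel in each dimension, which is one reason odd sizes such as 3 or 7 are used.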

The following code presents the implementation of CBAM (channel and spatial attention) and its integration into a residual block.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_channels: int, reduction_ratio: int = 16) -> None:
        super().__init__()
        reduced_channels = max(in_channels // reduction_ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP implemented with 1x1 convolutions
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = self.sigmoid(avg_out + max_out)
        return x * attention

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=padding, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel-wise average and max projections
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        combined = torch.cat([avg_out, max_out], dim=1)
        attention = self.sigmoid(self.conv(combined))
        return x * attention

class CBAM(nn.Module):
    def __init__(
        self, in_channels: int, reduction_ratio: int = 16, kernel_size: int = 7
    ) -> None:
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

class CBAMResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)
        out += identity
        out = F.relu(out)
        return out

The following code fragment performs a basic check of the CBAM module, analogous to the test applied in the case of the SE block. It validates that the input and output have the same shape and reports the number of parameters of the module.

import torch

def test_cbam() -> None:
    x = torch.randn(2, 64, 32, 32)  # Batch of 2, 64 channels, 32x32 feature map
    cbam = CBAM(in_channels=64)

    output = cbam(x)

    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"CBAM parameters: {sum(p.numel() for p in cbam.parameters())}")

    assert x.shape == output.shape, "Shape mismatch"
    print("CBAM test passed")

test_cbam()
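The parameter count reported by this test can be derived by hand. With bias-free layers, the shared channel MLP holds \(2C\lfloor C/r \rfloor\) weights, exactly like the SE excitation network, and the spatial convolution adds \(2k^2\) weights (two input channels, one output channel, a \(k \times k\) kernel). For \(C = 64\), \(r = 16\) and \(k = 7\) this gives \(2 \cdot 64 \cdot 4 + 2 \cdot 49 = 512 + 98 = 610\). The sketch below rebuilds only those three convolution layers and confirms the arithmetic; the helper name `cbam_param_count` is hypothetical.

```python
import torch.nn as nn


def cbam_param_count(channels: int, reduction_ratio: int = 16, kernel_size: int = 7) -> int:
    """Weights in a bias-free CBAM: shared channel MLP plus one 2->1 spatial conv."""
    reduced = max(channels // reduction_ratio, 1)
    channel_part = 2 * channels * reduced
    spatial_part = 2 * kernel_size * kernel_size
    return channel_part + spatial_part


# Rebuild the learnable layers of CBAM(in_channels=64) and count their weights directly
layers = [
    nn.Conv2d(64, 4, kernel_size=1, bias=False),            # channel MLP, reduction
    nn.Conv2d(4, 64, kernel_size=1, bias=False),            # channel MLP, expansion
    nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),  # spatial attention
]
counted = sum(p.numel() for layer in layers for p in layer.parameters())
print(cbam_param_count(64), counted)  # 610 610
```

Compared with the 512 parameters of the SE block at the same width, the spatial branch adds only 98 weights, and this amount does not grow with \(C\).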

In practice, CBAM is often reported to improve on SE, since it combines channel-level and spatial attention in a complementary way, at a small extra computational cost for the channel-wise pooling and the \(k \times k\) convolution. Spatial attention is particularly useful in tasks where the localization of objects or discriminative regions plays a critical role, such as object detection, semantic and instance segmentation, or recognition in scenes with multiple instances per image.