Graph-Based Models

Introduction

This chapter examines data with graph structure, that is, sets of nodes connected through known relationships. Such structure is extremely common in real-world applications, from proteins and molecules to traffic networks, social networks, and recommendation systems. The goal is to introduce differentiable models that operate systematically and efficiently on these data, combining theoretical foundations with practical implementations in PyTorch.

Until now, data without structure or with relatively simple structures have been primarily considered: tabular data represented as vectors, sets, sequences, and regular grids such as images. However, many types of data present more complex dependencies between their elements. Networks of very diverse nature (social, transportation, energy, communication networks) are composed of millions of units (people, users, products, stations) that interact only through a reduced number of connections, such as roads, electrical links, social interactions, or ratings in recommendation systems. This type of structure is naturally described through the language of graph theory, which provides a rigorous mathematical framework for modeling and analyzing these complex relationships.

What follows introduces the basic notions of graphs, their matrix representations, and the fundamental operations that allow defining learning models on them. Both the theoretical foundations of convolutional layers on graphs and a detailed reconstruction of the original Graph Neural Networks model proposed by Scarselli et al. in 2009 are presented, including practical implementations that illustrate the mathematical concepts.

Graphs and Features on Graphs

In its simplest form, a graph is described as a pair of sets \(G = (V, E)\), where \(V = \{1, \dots, n\}\) is the set of nodes (or vertices), and \(E\) is the set of edges. The edge set is a collection of node pairs, \(E \subseteq \{(i,j) \mid i,j \in V\}\), where each pair \((i,j)\) indicates that a connection (edge) exists between nodes \(i\) and \(j\); in an undirected graph, \((i,j)\) and \((j,i)\) denote the same connection. In practice, the number of nodes \(n = |V|\) and the number of edges \(m = |E|\) can vary from one graph to another within the same dataset.

Graphs generalize many previously seen structures. For example, a set can be viewed as a graph containing only self-loops of the form \((i,i)\), that is, each element connects only with itself. A fully connected graph, which contains all possible edges between pairs of nodes, relates to attention mechanisms. Images can be represented as graphs in which each pixel is associated with a node and neighboring pixels are connected following a regular grid structure. Similarly, a sequence can be viewed as a linear graph where each element connects with the next. In this sense, graphs provide a unifying framework: sets correspond to empty graphs or graphs with only loops, sequences to linear graphs, and images to regular graphs.

The connections of a graph can be equivalently represented through matrices. The most common representation is the adjacency matrix, a binary square matrix \(A \in \{0,1\}^{n \times n}\) defined by:

\[A_{ij} = \begin{cases} 1 & \text{if } (i,j) \in E,\\ 0 & \text{otherwise}. \end{cases}\]

In this format, a set is represented by the identity matrix \(A = I\) (only self-loops), a fully connected graph is represented by a matrix of all ones, and an image corresponds to a matrix with Toeplitz structure that reflects the local neighborhood structure of pixels. If the graph is undirected, every edge \((i,j)\) appears in pairs \((i,j)\) and \((j,i)\), and the adjacency matrix is symmetric, that is, \(A^\top = A\).
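The correspondence between familiar structures and adjacency matrices can be checked on a toy case. The following is a minimal sketch, assuming NumPy and 0-based node indices (the text numbers nodes from 1); the helper `adjacency_from_edges` is a hypothetical name introduced only for illustration.

```python
import numpy as np

def adjacency_from_edges(num_nodes, edges, undirected=True):
    """Build a dense 0/1 adjacency matrix from a list of (i, j) pairs."""
    A = np.zeros((num_nodes, num_nodes), dtype=int)
    for i, j in edges:
        A[i, j] = 1
        if undirected:
            A[j, i] = 1  # store the edge in both directions
    return A

n = 4
A_set = np.eye(n, dtype=int)        # a set: only self-loops
A_full = np.ones((n, n), dtype=int) # fully connected graph (self-loops included)
A_seq = adjacency_from_edges(n, [(0, 1), (1, 2), (2, 3)])  # a sequence as a path graph

print(A_seq)
print("symmetric:", np.array_equal(A_seq, A_seq.T))
```

The path graph stores each of its three undirected edges twice, once per direction, which is why the matrix is symmetric.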

Besides the adjacency matrix, other useful representations exist. One of them is the incidence matrix \(B \in \{0,1\}^{n \times |E|}\), where each column corresponds to an edge and each row to a node. Typically, \(B_{ij} = 1\) if node \(i\) participates in edge \(j\), and \(0\) otherwise. In simple undirected graphs where each edge connects exactly two nodes, every column of \(B\) sums to \(2\), so \(\mathbf{1}^\top B = 2 \cdot \mathbf{1}^\top\), where \(\mathbf{1}\) denotes a vector of all ones of the appropriate size.
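The column-sum property can be verified numerically. The sketch below assumes NumPy and a hypothetical 4-node graph chosen for illustration. It also checks the identity \(B B^\top = D + A\), which holds for the unoriented incidence matrix of a simple graph, with \(D\) the diagonal matrix of node degrees.

```python
import numpy as np

edges = [(0, 1), (0, 2), (1, 2), (2, 3)]  # hypothetical toy graph
n, m = 4, len(edges)

# Incidence matrix: one row per node, one column per edge
B = np.zeros((n, m), dtype=int)
for e, (i, j) in enumerate(edges):
    B[i, e] = 1
    B[j, e] = 1

col_sums = B.sum(axis=0)  # 1^T B: each edge touches exactly two nodes
degrees = B.sum(axis=1)   # B 1: number of edges incident to each node

# Adjacency and degree matrices of the same graph
A = np.zeros((n, n), dtype=int)
for i, j in edges:
    A[i, j] = A[j, i] = 1
D = np.diag(degrees)

print(col_sums, degrees)
print(np.array_equal(B @ B.T, D + A))
```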

Computational Representation of Graphs

To implement learning models on graphs, it becomes necessary to establish a computationally manageable representation. In the code, an adjacency list-based representation is used, which proves especially efficient for sparse graphs. The following class defines a graph with nodes, edges, and possible features associated with both nodes and edges:

class Graph:
    def __init__(self, num_nodes, edges, features, edge_features=None):
        """
        Args:
            num_nodes: Total number of nodes.
            edges: List of tuples (src, dst) representing undirected edges.
            features: Array or tensor of shape (num_nodes, feature_dim) with node features.
            edge_features: Array or tensor of shape (num_edges, edge_dim) with optional edge features.
        """
        self.num_nodes = num_nodes
        self.adj_list = {i: [] for i in range(num_nodes)}

        for i, (src, dst) in enumerate(edges):
            if edge_features is not None:
                self.adj_list[src].append((dst, edge_features[i]))
                self.adj_list[dst].append((src, edge_features[i]))
            else:
                self.adj_list[src].append((dst, None))
                self.adj_list[dst].append((src, None))

        self.features = features

Each node stores a feature vector that can encode, for example, properties of an atom in chemistry such as type, charge, hybridization, or valence. Edges, in turn, can include features such as bond type or bond length. For undirected graphs, the edge is explicitly stored in both directions in the adjacency list, which simplifies information aggregation between neighbors.
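To make the storage scheme concrete, the following is a minimal pure-Python sketch of the same adjacency-list idea without features. The function `build_adj_list` and the toy edge list are hypothetical and are not part of the chapter's code.

```python
def build_adj_list(num_nodes, edges):
    """Store every undirected edge in both directions, as the Graph class does."""
    adj = {i: [] for i in range(num_nodes)}
    for src, dst in edges:
        adj[src].append(dst)
        adj[dst].append(src)
    return adj

edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
adj = build_adj_list(4, edges)

# The neighborhood of a node is a direct lookup, and its degree is the list length
degrees = {node: len(neighbors) for node, neighbors in adj.items()}
print(adj)
print(degrees)
```

Memory grows with the number of edges rather than with the square of the number of nodes, which is what makes this layout attractive for sparse graphs.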

Graphs can incorporate a wide variety of features that describe additional information. In general terms, three types of features are distinguished: node features (attributes associated with each node), edge features (attributes associated with each edge), and graph features (attributes associated with the complete graph).

To begin, the simplest case is considered: only unstructured features at nodes are available. Each node \(i\) has an associated feature vector \(\mathbf{x}_i \in \mathbb{R}^c\), where \(c\) is the feature dimension. Gathering the features of all nodes in a single matrix yields the feature matrix \(X \in \mathbb{R}^{n \times c}\), where the \(i\)-th row of \(X\) corresponds to \(\mathbf{x}_i\). The graph is then described by the pair \((X, A)\), where \(A \in \mathbb{R}^{n \times n}\) is the adjacency matrix.

In most cases, the order of nodes in \(X\) and in \(A\) is irrelevant. This means that if a simultaneous permutation is applied to the rows of \(X\) and to the rows and columns of \(A\), the resulting graph is essentially the same. Mathematically, if \(P \in \{0,1\}^{n \times n}\) is a permutation matrix, then \((X, A)\) and \((PX, \, P A P^\top)\) represent the same graph. This permutation invariance property has direct consequences on the design of models that operate on graphs: the layers used must be, at least, equivariant with respect to permutations.
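The distinction between equivariance and invariance can be illustrated numerically. The sketch below assumes NumPy, a random feature matrix, and the simple layer \(f(X, A) = AX\) as a stand-in for a graph layer; it checks that relabeling the nodes permutes the output rows in the same way, while a sum over nodes is unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
n, c = 5, 3
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 4)]  # hypothetical 5-node cycle

A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0
X = rng.normal(size=(n, c))

perm = rng.permutation(n)
P = np.eye(n)[perm]  # permutation matrix: row k selects node perm[k]

X_p = P @ X          # relabeled features
A_p = P @ A @ P.T    # relabeled adjacency

# Equivariance: f(PX, PAP^T) = P f(X, A) for f(X, A) = A X
out_original = A @ X
out_permuted = A_p @ X_p
print(np.allclose(out_permuted, P @ out_original))

# Invariance: a sum over nodes does not depend on the node ordering
print(np.allclose(out_permuted.sum(axis=0), out_original.sum(axis=0)))
```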

Graph Topology, Degrees, and Normalized Matrices

Besides explicitly defined features, relevant information can be extracted directly from the graph topology. One of the most basic quantities associated with each node is its degree. The degree of a node \(i\), denoted as \(d_i\), is the number of nodes with which it is connected: \(d_i = \sum_j A_{ij}\). When \(A\) can take real values (for example, similarity weights), the degree is usually defined as the weighted sum of connections. The degree distribution \(\{d_i\}_{i=1}^n\) is an important characteristic of the graph.

The degrees can be gathered in a diagonal matrix called the degree matrix:

\[D = \begin{pmatrix} d_1 & 0 & \cdots & 0 \\ 0 & d_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & d_n \end{pmatrix}.\]

From \(A\) and \(D\), different normalized versions of the adjacency matrix can be defined, very useful in diffusion operations and in graph layers. A first option is row normalization, defined as \(A' = D^{-1} A\), that is, \(A'_{ij} = \frac{1}{d_i} A_{ij}\). In this case, each row of \(A'\) sums to one (assuming every node has \(d_i > 0\)), so each row is a probability distribution over the neighbors of the corresponding node: \(A'_{ij}\) is the probability that a random walk located at node \(i\) moves to node \(j\) in one step.

A widely used symmetric normalization is \(A' = D^{-1/2} A D^{-1/2}\), with entries \(A'_{ij} = \frac{A_{ij}}{\sqrt{d_i d_j}}\). This version assigns a weight to each connection that takes into account the degree of both connected nodes. In the context of signal processing on graphs, these matrices are often called graph-shift matrices.
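Both normalizations can be computed in a few lines. The following sketch assumes NumPy, a hypothetical 4-node graph, and strictly positive degrees (an isolated node would cause a division by zero).

```python
import numpy as np

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

d = A.sum(axis=1)  # degrees d_i = sum_j A_ij
D = np.diag(d)     # degree matrix

# Row normalization D^{-1} A: divide row i by d_i
A_row = A / d[:, None]

# Symmetric normalization D^{-1/2} A D^{-1/2}: entry (i, j) is A_ij / sqrt(d_i d_j)
d_inv_sqrt = 1.0 / np.sqrt(d)
A_sym = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

print(A_row.sum(axis=1))  # every row sums to one
print(np.allclose(A_sym, A_sym.T))
```

Row normalization generally breaks symmetry, whereas the symmetric variant preserves it, which is one reason the latter is common in spectral arguments.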

In practice, most real graphs are sparse: each node connects only with a small fraction of the total nodes. This property is crucial, as it allows using specific implementations and data structures for sparse matrices, with better computational complexity. Modern numerical computation and machine learning libraries (including implementations in JAX, PyTorch, or TensorFlow) provide explicit support for this type of matrices.
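Sparsity can be exploited even without a dedicated sparse-matrix type. The sketch below assumes NumPy and an edge-list layout (the `src` and `dst` index arrays are illustrative names); it computes the product \(A\mathbf{x}\) by scattering neighbor values, and compares the result against a dense matrix built only for verification.

```python
import numpy as np

edges = np.array([(0, 1), (0, 2), (1, 2), (2, 3)])  # hypothetical toy graph
n = 4

# Each undirected edge contributes two directed entries
src = np.concatenate([edges[:, 0], edges[:, 1]])
dst = np.concatenate([edges[:, 1], edges[:, 0]])

x = np.array([1.0, 2.0, 3.0, 4.0])

# out[i] accumulates x[j] over all neighbors j of i; cost is O(number of edges)
out = np.zeros(n)
np.add.at(out, dst, x[src])

# Dense reference, O(n^2) memory, used here only to check the result
A_dense = np.zeros((n, n))
A_dense[src, dst] = 1.0
print(out, np.allclose(out, A_dense @ x))
```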

Diffusion Operations on Graphs

The fundamental operation to be used on graphs is diffusion, which can be interpreted as a smoothing of node features following the graph topology. To illustrate this idea, consider a scalar feature associated with each node, collected in a vector \(\mathbf{x} \in \mathbb{R}^n\). The simplest diffusion operation is \(\mathbf{x}' = A \mathbf{x}\), where \(A\) can be the adjacency matrix, a normalized version, or any shift matrix on the graph.

It is convenient to rewrite this operation at the node level. The updated value for node \(i\) is \(x'_i = \sum_{j \in \mathcal{N}(i)} A_{ij} x_j\), where \(\mathcal{N}(i) = \{j \mid (i,j) \in E\}\) denotes the 1-hop neighborhood of node \(i\), that is, the set of nodes directly connected with \(i\). If \(x_i\) is interpreted as a physical quantity at node \(i\), multiplication by \(A\) can be viewed as a diffusion process: the value at each node is replaced by a weighted sum of the values in its neighborhood, which becomes a weighted average when \(A\) is row-normalized. By repeatedly applying this operation, the signal is progressively smoothed over the graph.
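The smoothing effect can be observed on a small example. The sketch below assumes NumPy, a 5-node path graph, and a row-normalized matrix with added self-loops; the self-loops are an assumption made here so that each node keeps part of its own value. It tracks the spread between the largest and smallest node value across iterations.

```python
import numpy as np

n = 5
A = np.zeros((n, n))
for i in range(n - 1):  # path graph 0 - 1 - 2 - 3 - 4
    A[i, i + 1] = A[i + 1, i] = 1.0

A_hat = A + np.eye(n)                          # self-loops (assumption)
P = A_hat / A_hat.sum(axis=1, keepdims=True)   # row-normalized: rows sum to one

x = np.array([1.0, 0.0, 0.0, 0.0, 0.0])        # all the signal starts at node 0
spreads = [x.max() - x.min()]
for _ in range(200):
    x = P @ x                                  # one diffusion step
    spreads.append(x.max() - x.min())

print(spreads[0], spreads[10], spreads[-1])
print(x)
```

Because every row of the matrix is a set of averaging weights, no step can push a value outside the range of the previous values, so the spread never increases.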

Another fundamental matrix in graph analysis is the Laplacian matrix, defined as \(L = D - A\), where \(D\) is the degree matrix. A diffusion iteration based on the Laplacian can be written at the node level as:

\[[L \mathbf{x}]_i = \sum_{j \in \mathcal{N}(i)} A_{ij} (x_i - x_j).\]

This expression shows that the Laplacian is intimately related to the notion of gradient or variation of the signal over the graph: it measures the difference between the value at a node and the value at its neighbors, weighted by the edge weights. An important property is that the vector of all ones \(\mathbf{1}\) is always an eigenvector of the Laplacian associated with eigenvalue zero. In undirected graphs with nonnegative weights the Laplacian is positive semidefinite, so zero is its smallest eigenvalue, and this eigenvalue is simple exactly when the graph is connected. This relates to the fact that a constant signal over the graph presents no variation between neighboring nodes.
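These properties can be checked on a toy graph. The sketch below assumes NumPy and a hypothetical connected, unweighted 4-node graph. It compares the matrix product with the node-level formula, and also evaluates the quadratic form \(\mathbf{x}^\top L \mathbf{x}\), which for an unweighted undirected graph equals the sum of squared differences across edges.

```python
import numpy as np

edges = [(0, 1), (0, 2), (1, 2), (2, 3)]  # hypothetical connected graph
n = 4
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0
D = np.diag(A.sum(axis=1))
L = D - A

x = np.array([1.0, -2.0, 0.5, 3.0])

# Node-level form: [L x]_i = sum_j A_ij (x_i - x_j)
Lx_nodewise = np.array(
    [sum(A[i, j] * (x[i] - x[j]) for j in range(n)) for i in range(n)]
)

quad_form = x @ L @ x
edge_sum = sum((x[i] - x[j]) ** 2 for i, j in edges)

eigvals = np.linalg.eigvalsh(L)  # L is symmetric, so eigvalsh applies

print(np.allclose(L @ x, Lx_nodewise))
print(np.allclose(L @ np.ones(n), 0.0))  # constant signal: eigenvalue zero
print(quad_form, edge_sum)
print(eigvals)
```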

Transition Function and State Diffusion

The central objective of a GNN consists of learning meaningful representations for nodes, their edges, or the complete graph. Instead of treating each input independently, as in traditional networks, nodes interact through edges: node representations are updated using information from their neighbors. This dynamic is formalized through the transition function. Let \(x_n\) be the state vector of node \(n\), which must summarize its local information and that of its neighborhood. The model by Scarselli et al. proposes an iterative update of the form:

\[x_n^{(t+1)} = f_w\bigl(l_n,\ l_{\text{co}[n]},\ x_{\text{ne}[n]}^{(t)},\ l_{\text{ne}[n]} \bigr),\]

where \(x_n^{(t+1)}\) represents the new state of node \(n\) at iteration \(t+1\), \(l_n\) are the features of node \(n\), \(l_{\text{co}[n]}\) are the features of edges connected to \(n\), \(x_{\text{ne}[n]}^{(t)}\) are the states of neighbors of \(n\) at iteration \(t\), \(l_{\text{ne}[n]}\) are the features of neighboring nodes of \(n\), and \(f_w\) is the transition function parameterized by \(w\), implemented through a neural network.

Informally, this can be understood as a message-passing process: each node begins with an initial state, receives information from its neighbors (their states and features), updates its state, and repeats this process. After several iterations, the states incorporate both local and global information from the graph. In the proposed implementation, the transition function is decomposed into two networks. For a node \(u\) and a neighbor \(n\), a transition network takes the features of both nodes and of the edge connecting them and produces a matrix \(A_{n,u}\) that transforms the neighbor's state \(x_n\); a forcing network takes the features of the neighbor \(n\) and produces a bias term \(b_n\) that is added to this linear contribution.

The code for this transition function is:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

class F_function(nn.Module):
    def __init__(self, state_dim, hidden_neurons=5, feature_dim=None, edge_dim=None):
        super(F_function, self).__init__()

        self.state_dim = state_dim
        self.feature_dim = feature_dim
        self.edge_dim = edge_dim

        transition_input_dim = 2 * feature_dim + (edge_dim if edge_dim is not None else 0)
        self.transition_network = nn.Sequential(
            spectral_norm(nn.Linear(transition_input_dim, hidden_neurons)),
            nn.Tanh(),
            spectral_norm(nn.Linear(hidden_neurons, state_dim * state_dim))
        )

        self.forcing_network = nn.Sequential(
            spectral_norm(nn.Linear(feature_dim, hidden_neurons)),
            nn.ReLU(),
            spectral_norm(nn.Linear(hidden_neurons, state_dim))
        )

    def forward(self, graph, states):
        """
        Args:
            graph: Graph object.
            states: Tensor of shape (num_nodes, state_dim) with current states.

        Returns:
            new_states: Tensor (num_nodes, state_dim) with new states.
        """
        device = states.device
        new_states = torch.zeros_like(states, device=device)

        for u in range(graph.num_nodes):
            sum_term = torch.zeros(self.state_dim, device=device)

            for (n, edge_features) in graph.adj_list[u]:
                if edge_features is None:
                    input_transition = torch.cat([graph.features[n], graph.features[u]], dim=0)
                else:
                    input_transition = torch.cat([graph.features[n], graph.features[u], edge_features], dim=0)

                phi_w = self.transition_network(input_transition)
                A_n_u = phi_w.view(self.state_dim, self.state_dim)

                mu = 1.0 / len(graph.adj_list[u])
                A_n_u = mu * A_n_u

                b_n = self.forcing_network(graph.features[n])

                sum_term += torch.matmul(A_n_u, states[n]) + b_n.squeeze()

            new_states[u] = sum_term

        return new_states

For each node \(u\), the contributions of its neighbors \(n \in \mathcal{N}(u)\) are summed in the form:

\[x_u^{\text{new}} = \sum_{n \in \mathcal{N}(u)} \bigl( \mu_u A_{n,u} x_n + b_n \bigr), \qquad \mu_u = \frac{1}{\deg(u)},\]

where \(A_{n,u}\) comes from the transition network and \(b_n\) from the forcing network. The factor \(\mu_u\) scales only the linear term, as in the code, and contributes to the numerical and theoretical stability of the model by moderating the total influence of neighbors on the node.

State Convergence, Fixed Points, and Contraction

The updates defined by the transition function are applied iteratively until reaching a stationary solution or fixed point. The objective consists of finding a state \(x^\star\) such that \(x^\star = F(x^\star)\), where \(F\) is the global mapping that, from the current state of all nodes, returns the new state after a transition iteration.

Mathematically, to ensure the existence and uniqueness of a fixed point, as well as the convergence of the Picard iteration \(x^{(t+1)} = F(x^{(t)})\), the concept of contractive mapping is invoked. Let \((X, d)\) be a metric space. A function \(F: X \to X\) is a contraction if there exists a constant \(c \in [0,1)\) such that:

\[d(F(x), F(y)) \leq c\, d(x, y) \quad \text{for all } x, y \in X.\]

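In the GNN context, the state space is usually \(X = \mathbb{R}^{N \times d}\), where \(N\) is the number of nodes and \(d\) now denotes the state dimension rather than the metric, equipped with a matrix norm such as the Frobenius norm. The contraction condition can be expressed as \(\|F(x) - F(y)\|_{\mathsf{F}} \leq c_{\mathsf{F}} \, \|x - y\|_{\mathsf{F}}\), with \(c_{\mathsf{F}} < 1\). If \(F\) is differentiable, one way to control \(c_{\mathsf{F}}\) is through the spectral norm of the Jacobian \(J_F(x)\) of the flattened state: if \(\|J_F(x)\|_2 \leq c\) for all \(x\), then \(F\) is Lipschitz with constant \(c\) in the Frobenius norm.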
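A short derivation makes the role of the degree scaling explicit. It assumes the affine update implemented in F_function, with \(\mu_u = 1/|\mathcal{N}(u)|\) applied to the linear term, nodes with at least one neighbor, and a hypothetical uniform bound \(a\) on the spectral norms \(\|A_{n,u}\|_2\); this bound is an assumption introduced here, not a quantity computed in the code. Since \(A_{n,u}\) and \(b_n\) depend only on features, the Jacobian does not depend on the state:

\[\frac{\partial F(x)_u}{\partial x_n} = \begin{cases} \mu_u A_{n,u} & \text{if } n \in \mathcal{N}(u),\\ 0 & \text{otherwise,} \end{cases}\]

and, because the bias terms cancel in the difference, averaging over the neighbors gives

\[\|F(x)_u - F(y)_u\|_2 \leq \frac{1}{|\mathcal{N}(u)|} \sum_{n \in \mathcal{N}(u)} \|A_{n,u}\|_2 \, \|x_n - y_n\|_2 \leq a \, \max_{n} \|x_n - y_n\|_2 .\]

Under these assumptions, \(a < 1\) would be sufficient for \(F\) to be a contraction in the norm \(\max_u \|x_u\|_2\). This is a sufficient condition in that particular norm and should not be read as the value of \(c_{\mathsf{F}}\) in the Frobenius norm.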

In practice, the state iteration is implemented simply:

def compute_states(graph, transition, max_iters=200, epsilon=1e-3):
    """
    Iterates the transition function until reaching an approximate fixed point.

    Args:
        graph: Graph object.
        transition: F_function module (transition function).
        max_iters: Maximum number of iterations.
        epsilon: Convergence threshold (norm of difference between iterations).

    Returns:
        states: Tensor (num_nodes, state_dim) with converged states.
    """
    device = graph.features.device
    states = torch.zeros(graph.num_nodes, transition.state_dim, device=device)

    for _ in range(max_iters):
        new_states = transition(graph, states)
        diff = torch.norm(new_states - states)
        states = new_states
        if diff < epsilon:
            break
    else:
        print("Max iterations reached but not converged")

    return states

Starting from an initial state, for example all zeros, the iteration \(x^{(t+1)} = F(x^{(t)})\) is performed until the norm of the difference \(\|x^{(t+1)} - x^{(t)}\|\) is less than a threshold \(\varepsilon\). If the transition function is indeed a contraction, the Banach fixed-point theorem ensures that this process converges for any initial condition and that the limit point is unique.
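The independence from the initial condition can be illustrated on a simpler map than the GNN transition. The sketch below assumes NumPy and a hypothetical affine map \(F(x) = Mx + b\) with \(\|M\|_2 = 0.8\), for which the fixed point is available in closed form as \((I - M)^{-1} b\). The names `picard`, `M`, and `b` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 6
M = rng.normal(size=(dim, dim))
M *= 0.8 / np.linalg.norm(M, 2)  # rescale so the spectral norm is 0.8 < 1
b = rng.normal(size=dim)

def F(x):
    """Affine contraction used as a stand-in for the global transition map."""
    return M @ x + b

def picard(x0, max_iters=200, epsilon=1e-10):
    """Iterate x <- F(x) until successive iterates differ by less than epsilon."""
    x = x0
    for t in range(max_iters):
        x_new = F(x)
        if np.linalg.norm(x_new - x) < epsilon:
            return x_new, t + 1
        x = x_new
    return x, max_iters

x_star = np.linalg.solve(np.eye(dim) - M, b)  # closed-form fixed point

x_a, iters_a = picard(np.zeros(dim))
x_b, iters_b = picard(10.0 * rng.normal(size=dim))

print(iters_a, iters_b)
print(np.allclose(x_a, x_star), np.allclose(x_b, x_star))
```

Both starting points reach the same limit; the starting point affects only the number of iterations needed.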

Instead of assuming that the transition function is a contraction, this property is imposed during training through a penalty term in the loss function. The idea consists of estimating the spectral norm \(\|J_F(x^\star)\|_2\) at the approximate fixed point \(x^\star\) and penalizing its excess over a threshold \(\tau < 1\). To avoid explicitly forming the Jacobian, which would have dimensions \((Nd) \times (Nd)\), a power iteration method combined with vector–Jacobian products provided by autograd is employed.

The penalty calculation is implemented as:

import torch.autograd as autograd

def calculate_spectral_norm_penalty(transition_model, graph, states, k=5, target=0.95):
    """
    Calculates the approximate spectral norm of the Jacobian of the transition function
    at point 'states' and returns a penalty if this norm exceeds 'target'.

    Args:
        transition_model: F_function module.
        graph: Graph on which the state is evaluated.
        states: States at the approximate fixed point (x*).
        k: Number of power iterations to estimate the norm.
        target: Desired contraction constant (< 1).

    Returns:
        penalty: Scalar (PyTorch tensor) with the penalty value.
    """
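    x = states.detach().requires_grad_()
    F_x = transition_model(graph, x)

    # Autograd provides vector-Jacobian products (J^T w) directly. Jacobian-vector
    # products (J v) are obtained with the double-backward trick: J^T u is
    # differentiated with respect to a dummy cotangent u.
    u = torch.zeros_like(F_x, requires_grad=True)
    Jt_u = autograd.grad(F_x, x, grad_outputs=u, create_graph=True)[0]

    def jvp(v, create_graph=False):
        return autograd.grad(Jt_u, u, grad_outputs=v,
                             retain_graph=True, create_graph=create_graph)[0]

    def vjp(w):
        return autograd.grad(F_x, x, grad_outputs=w, retain_graph=True)[0]

    v = torch.randn_like(x)
    v = v / (torch.norm(v) + 1e-6)

    # Power iteration on J^T J, whose dominant eigenvector is the top right
    # singular vector of J. Repeating only v <- J^T v would instead track the
    # dominant eigenvalue of J, which can underestimate the spectral norm.
    for _ in range(k):
        w = vjp(jvp(v))
        v = w / (torch.norm(w) + 1e-6)

    # Differentiable estimate of the largest singular value, ||J v||
    spectral_norm_estimate = torch.norm(jvp(v, create_graph=True))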

    penalty = torch.relu(spectral_norm_estimate - target)

    return penalty

This penalty is combined with the task loss during training, so the model simultaneously learns to solve the task and to keep the transition function within the contractive regime. Since the constraint is enforced as a soft penalty evaluated at the approximate fixed point, it encourages, rather than strictly guarantees, the stability and convergence of the state iteration.
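Power iteration for the largest singular value alternates products with \(J\) and \(J^\top\), which amounts to power iteration on \(J^\top J\). The following matrix-free sketch assumes NumPy, with an explicit matrix standing in for the Jacobian; with autograd, the two callbacks would be a Jacobian-vector product and a vector-Jacobian product. The function name and the example matrix are hypothetical. The example is chosen so that the spectral radius and the spectral norm differ, which shows why the eigenvalues alone do not certify a contraction.

```python
import numpy as np

def spectral_norm_power_iteration(matvec, rmatvec, dim, k=20, seed=0):
    """Estimate the largest singular value of J from products J v and J^T w."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    for _ in range(k):
        w = rmatvec(matvec(v))                # (J^T J) v
        v = w / (np.linalg.norm(w) + 1e-12)   # renormalize
    return np.linalg.norm(matvec(v))          # ||J v|| approximates sigma_max

# Non-normal example: both eigenvalues are zero, yet the map stretches by 2
J = np.array([[0.0, 2.0],
              [0.0, 0.0]])

sigma = spectral_norm_power_iteration(lambda v: J @ v, lambda w: J.T @ w, dim=2)
rho = max(abs(np.linalg.eigvals(J)))          # spectral radius

print("spectral norm estimate:", sigma)
print("exact spectral norm:   ", np.linalg.norm(J, 2))
print("spectral radius:       ", rho)
```

For this matrix a penalty based on the spectral radius would be zero, even though the map is not a contraction in the Euclidean norm.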

Output Function

Once converged states \(x_n^\star\) are available for each node \(n\), an output function is required that transforms these states and the original features into predictions for the desired task. Formally, for each node \(n\) it is defined:

\[o_n = g_w(l_n, x_n),\]

where \(l_n\) are the original features of the node, \(x_n\) is the converged state of the node, \(g_w\) is the output function parameterized by \(w\), and \(o_n\) is the final output.

A simple and general implementation of \(g_w\) is:

class G_Function(nn.Module):
    def __init__(self, feature_dim, state_dim, output_dim, hidden_neurons=5):
        super(G_Function, self).__init__()

        input_dim = feature_dim + state_dim
        self.state_dim = state_dim
        self.output_dim = output_dim

        self.fc = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(input_dim, hidden_neurons)),
            nn.ReLU(),
            nn.utils.spectral_norm(nn.Linear(hidden_neurons, output_dim))
        )

    def forward(self, graph, states):
        device = states.device
        outputs = torch.zeros(graph.num_nodes, self.output_dim,
                              dtype=states.dtype, device=device)

        for v in range(graph.num_nodes):
            input_v = torch.cat([graph.features[v], states[v]], dim=0)
            outputs[v] = self.fc(input_v)

        return outputs

Complete GNN Architecture

The combination of the transition function, the state iteration until reaching a fixed point, and the output function gives rise to the global GNN architecture:

class GNN(nn.Module):
    def __init__(self, feature_dim, state_dim, output_dim, edge_dim=None, hidden_neurons=5):
        super(GNN, self).__init__()

        self.state_dim = state_dim
        self.transition = F_function(
            state_dim=state_dim,
            hidden_neurons=hidden_neurons,
            feature_dim=feature_dim,
            edge_dim=edge_dim
        )
        self.output = G_Function(
            feature_dim=feature_dim,
            state_dim=state_dim,
            output_dim=output_dim,
            hidden_neurons=hidden_neurons
        )

    def forward(self, graph):
        states = compute_states(graph, self.transition)
        return self.output(graph, states)

GNN Training

Training integrates the fixed-point iteration and the contraction penalty. The total loss is composed of:

\[L_{\text{total}} = L_{\text{task}} + \lambda_{\text{pen}} \cdot \max\bigl(0,\ \|J_F(x^\star)\|_2 - \tau\bigr),\]

where \(L_{\text{task}}\) is the loss associated with the task, \(\lambda_{\text{pen}}\) is the penalty weight, and \(\tau < 1\) is the target contraction constant.

Training uses Rprop (Resilient Backpropagation), an algorithm that uses only the sign of each partial derivative, ignoring its magnitude, and adapts a separate step size for every parameter. This proves especially useful for stabilizing the training of models with implicit deep structure.
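The sign-based update rule can be isolated on a toy problem. The sketch below assumes NumPy and a simple quadratic objective; the function `rprop_minimize` and the step-size bounds `delta_min` and `delta_max` are illustrative choices introduced here, not values taken from the chapter.

```python
import numpy as np

def rprop_minimize(grad_fn, x0, steps=200, delta0=0.1,
                   eta_plus=1.2, eta_minus=0.5,
                   delta_min=1e-6, delta_max=50.0):
    """Sign-based updates with one adaptive step size per coordinate."""
    x = np.array(x0, dtype=float)
    delta = np.full_like(x, delta0)
    prev_grad = np.zeros_like(x)
    for _ in range(steps):
        g = grad_fn(x)
        sign_changed = prev_grad * g < 0
        # Shrink the step where the gradient sign flipped, grow it elsewhere
        delta = np.where(sign_changed, delta * eta_minus, delta * eta_plus)
        delta = np.clip(delta, delta_min, delta_max)
        # Only the sign of the gradient is used, not its magnitude
        x = x - np.sign(g) * delta
        prev_grad = g
    return x

target = np.array([3.0, -2.0])
# Gradient of f(x) = sum((x - target)^2)
x_final = rprop_minimize(lambda x: 2.0 * (x - target), x0=[0.0, 0.0])
print(x_final)
```

Each coordinate first accelerates toward the minimum, then halves its step every time it overshoots, so the oscillation around the minimum shrinks over time.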

from tqdm.auto import tqdm

def train(gnn, train_data, val_data,
          criterion, validate_criterion,
          validate_every=20,
          epochs=5000, lr=0.01,
          lambda_penalty=1.0, target_contraction=0.95):
    eta_plus = 1.2
    eta_minus = 0.5
    initial_delta = 0.1

    prev_grads = {}
    deltas = {}

    for name, param in gnn.named_parameters():
        if param.requires_grad:
            prev_grads[name] = torch.zeros_like(param.data)
            deltas[name] = torch.full_like(param.data, initial_delta)

    best_val_loss = float('inf')
    best_model_state = None
    train_losses = []
    val_losses = []

    print(f"Starting training with {epochs} epochs")
    print(f"Training data size: {len(train_data)}, Validation data size: {len(val_data)}")

    for epoch in tqdm(range(epochs), desc="Training"):
        gnn.train()
        gnn.zero_grad()

        total_train_loss = 0.0

        for graph, labels, mask in train_data:
            states_star = compute_states(graph, gnn.transition)
            preds = gnn.output(graph, states_star).squeeze()
            task_loss = criterion(preds[mask], labels[mask])

            penalty = calculate_spectral_norm_penalty(
                gnn.transition, graph, states_star,
                k=5, target=target_contraction
            )
            penalty_loss = lambda_penalty * penalty
            total_loss = task_loss + penalty_loss
            total_loss.backward()

            total_train_loss += total_loss.item()

        with torch.no_grad():
            for name, param in gnn.named_parameters():
                if param.grad is not None:
                    sign_changed = (prev_grads[name] * param.grad.data) < 0

                    if epoch == 0:
                        param.data -= lr * param.grad.data
                    else:
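                        deltas[name][sign_changed] *= eta_minus
                        deltas[name][~sign_changed] *= eta_plus
                        # Keep step sizes bounded (the bounds are an assumed choice):
                        # otherwise a parameter whose gradient stays at zero lets its
                        # delta overflow to inf, and sign(0) * inf yields NaN
                        deltas[name].clamp_(min=1e-6, max=50.0)
                        param.data -= torch.sign(param.grad.data) * deltas[name]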

                    prev_grads[name].copy_(param.grad.data)

        gnn.zero_grad()
        avg_train_loss = total_train_loss / max(1, len(train_data))
        train_losses.append(avg_train_loss)

        if (epoch + 1) % validate_every == 0:
            gnn.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for graph, labels, mask in val_data:
                    preds = gnn(graph).squeeze()
                    loss = validate_criterion(preds[mask], labels[mask])
                    total_val_loss += loss.item()

            avg_val_loss = total_val_loss / max(1, len(val_data))
            val_losses.append(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_state = {k: v.clone() for k, v in gnn.state_dict().items()}

    if best_model_state is not None:
        gnn.load_state_dict(best_model_state)

    return gnn, train_losses, val_losses

This procedure is designed so that the model learns meaningful relationships between nodes and edges for the specific task, the transition function is kept within a contractive regime that favors convergence of the state iteration, and parameter updates are robust against gradient oscillations thanks to the use of Rprop.

Complete Example: Node Classification in a Synthetic Graph

To illustrate the complete functioning of the GNN quickly, a small synthetic graph with 8 nodes divided into two communities is used.

import numpy as np
import matplotlib.pyplot as plt

# Small graph: 8 nodes, 2 communities
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (4, 5), (4, 6), (5, 6), (5, 7), (3, 4)]
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.float32)

# Features: normalized degree
num_nodes = 8
adj_matrix = np.zeros((num_nodes, num_nodes))
for src, dst in edges:
    adj_matrix[src, dst] = 1
    adj_matrix[dst, src] = 1

degrees = adj_matrix.sum(axis=1)
features = torch.tensor(degrees / degrees.max(), dtype=torch.float32).unsqueeze(1)

# Create graph
graph = Graph(num_nodes, edges, features)

# Select the device and move tensors before building the datasets, so the
# tuples below hold tensors that live on the same device as the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph.features = graph.features.to(device)
labels = labels.to(device)

# Train/val masks
train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
train_mask[:6] = True
val_mask[6:] = True

train_data = [(graph, labels, train_mask)]
val_data = [(graph, labels, val_mask)]

# Create model

model = GNN(feature_dim=1, state_dim=2, output_dim=1, hidden_neurons=4).to(device)

epochs = 5
validate_every = 1

# Train
criterion = nn.BCEWithLogitsLoss()
model, train_losses, val_losses = train(
    model, train_data, val_data,
    criterion=criterion, validate_criterion=criterion,
    epochs=epochs, lr=0.01, lambda_penalty=0.3,
    target_contraction=0.95, validate_every=validate_every
)

# Evaluate
model.eval()
with torch.no_grad():
    preds = model(graph).squeeze()
    pred_labels = (torch.sigmoid(preds) > 0.5).float()
    accuracy = (pred_labels == labels).float().mean()
    print(f"\nAccuracy: {accuracy.item():.2%}")

# Visualize losses
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)

plt.subplot(1, 2, 2)
val_epochs = list(range(validate_every, len(train_losses) + 1, validate_every))
plt.plot(val_epochs, val_losses)
plt.title('Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.tight_layout()
plt.show()