Neural ODEs

Deep Learning
Author

Sandro Cavallari

Published

April 10, 2024

Ordinary Differential Equations

Ordinary Differential Equations, or ODEs, are equations involving a single independent variable (usually called time \(t\)) and one or more derivatives of functions defined in terms of that independent variable. Formally,

\[ f(x(t), t) = x'(t) = \frac{d x(t)}{d t} = \lim_{\Delta t \rightarrow 0} \frac{x(t + \Delta t) - x(t)}{\Delta t} \]

where:

  • \(t\) represents the time, or any other independent variable serving as the domain of the derivative operator;
  • \(x(t) \in \mathbb{R}^d\) is the dependent variable defining the system’s state;
  • \(x'(t) \in \mathbb{R}^d\) is the first order derivative of \(x(t)\);
  • \(f : \mathbb{R}^d \times \mathbb{R}^+ \rightarrow \mathbb{R}^d\) is the vector-field differential function describing the system’s evolution over time.
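To build intuition for the limit above, the derivative can be approximated numerically with a small but finite \(\Delta t\). A minimal, purely illustrative sketch:

def finite_difference(x, t, dt=1e-6):
    """Forward-difference approximation of x'(t) for a small but finite dt."""
    return (x(t + dt) - x(t)) / dt

# e.g., x(t) = t**2 has derivative x'(t) = 2t
print(finite_difference(lambda t: t**2, 3.0))  # ~6.0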

As \(f\) defines the evolution of a complex system over every infinitesimal interval of time \(\Delta t\), we can formally define an ODE problem as:

\[ \begin{align} x(t + \Delta t) & = x(t) + \Delta t \cdot f(x(t), t) \\ \text{s.t.} & ~~ \Delta t \rightarrow 0. \end{align} \]

Computing a Solution

In general, solving an ODE involves computing the antiderivative of \(f\), in other words the integral of \(f\). As the integral of any function involves an arbitrary constant, usually denoted \(C\), an initial condition \(x_0\) must be specified to guarantee that the solution of the ODE is unique:

\[ x_t = x_0 + \int_0^t f(x_\tau) \, d\tau \]
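For example, for linear dynamics \(f(x_\tau) = a x_\tau\) the integral admits a closed-form solution:

\[ x_t = x_0 + \int_0^t a x_\tau \, d\tau = x_0 e^{at}. \]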

Note that, given the initial state \(x_0\) and a set of points in time \(\{ t_0, ..., t_N\}\), the objective is to obtain the state solution \(x_{0:N} \equiv \{ x_0, ..., x_N\}\). Unfortunately, solving the above integral analytically is possible only for a limited class of differential functions. Therefore, numerical solvers are used in practice, as sketched below.
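The simplest numerical solver is the explicit Euler method, which repeatedly applies the finite update \(x(t + \Delta t) \approx x(t) + \Delta t \cdot f(x(t), t)\) from the definition above. A minimal sketch (the function name and signature are illustrative, not a library API):

import numpy as np

def euler_odeint(f, x0, t_span):
    """Explicit Euler: x(t + dt) ≈ x(t) + dt * f(t, x(t))."""
    xs = [np.asarray(x0, dtype=float)]
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        xs.append(xs[-1] + (t1 - t0) * f(t0, xs[-1]))
    return np.stack(xs)  # [T, d]

In practice we rely on the solvers shipped with torchdyn. As a running example, the rest of this section uses the Van der Pol (VDP) oscillator, a non-linear system whose dynamics, written as a first-order system, are:

\[ \begin{align} x_0'(t) & = x_1(t) \\ x_1'(t) & = \mu (1 - x_0(t)^2) x_1(t) - x_0(t) \end{align} \]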

import numpy as np
import torch
import torch.nn as nn
from torchdyn.numerics import odeint

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cpu")


class VDPoscillator(nn.Module):
    """Vector field of the Van der Pol oscillator, written as a first-order system."""

    def __init__(self, mu: float) -> None:
        super().__init__()
        self.mu = mu

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # x_0' = x_1
        var_0 = x[..., 1]
        # x_1' = mu * (1 - x_0^2) * x_1 - x_0
        var_1 = self.mu * (1 - x[..., 0] ** 2) * x[..., 1] - x[..., 0]
        return torch.stack(
            (var_0, var_1),
            dim=-1,
        )


from plot_vector_field import plot_vdp_animation

vdp = VDPoscillator(1.0).to(device)

# initial value, of shape [N, d]
x0 = torch.tensor([[1.0, 0.0]]).float().to(device)

# integration time points, of shape [T]
t_span = torch.linspace(0.0, 15.0, 500).to(device)

with torch.no_grad():
    t_eval, vdp_sol_euler = odeint(
        vdp, x0, t_span=t_span, solver="euler"
    )  # [T], [T, N, d]

t_eval.size(), vdp_sol_euler.size()

# anim = plot_vdp_animation(
#     t=t_eval.detach().cpu().numpy(),
#     X=vdp_sol_euler.detach().cpu().numpy(),
#     ode_rhs=vdp,
# )
(torch.Size([500]), torch.Size([500, 1, 2]))

As mentioned above, the solution depends on the initial state of the system. Below is a demonstration of how the system behaves when started from different initial points.

from plot_vector_field import plot_ode

x0 = (
    torch.tensor(
        [
            [1.0, 0.0],
            [-2.0, -3.0],
            [-2.0, 3.0],
        ]
    )
    .float()
    .to(device)
)

# integration time points, of shape [T]
t_span = torch.linspace(0.0, 15.0, 500).to(device)

# forward integration
with torch.no_grad():
    t_vdps_eval, vdps_sol_euler = odeint(vdp, x0, t_span=t_span, solver="euler")

plot_ode(
    t=t_vdps_eval.detach().cpu().numpy(),
    X=vdps_sol_euler.detach().cpu().numpy(),
    ode_rhs=vdp,
)

Neural ODEs

Neural ODEs are a family of ODEs for which the vector field \(f(x_t, t)\) is defined by a neural network. As such, \(f(x_t, t)\) is both differentiable and learnable. Thus, given a set of observations \(y_{0:N}\) from an unknown dynamical system, we can use them to learn a model of the evolution of the system’s dynamics.

Problem Formulation

Consider a dataset of noisy observations \(y_n\), where each observation is a perturbation of an unknown state \(x_n\) generated by an unknown underlying dynamics \(f_{true}\):

\[ \begin{align} y_{n} & = x_{n} + \epsilon, ~~ \epsilon \sim \mathcal{N}(0, \sigma^2) \\ x_{n} & = x_{0} + \int_{0}^{n} f_{true}(x_{\tau}) \, d\tau \end{align} \]

The objective is to learn a neural network \(f_\theta\) that matches the unknown dynamics:

\[ f_\theta \approx f_{true}. \]

from torch import nn, Tensor
from torchdyn.core import NeuralODE


class VectorField(nn.Module):
    def __init__(self, d: int):
        """d - ODE dimensionality"""
        super().__init__()
        self.d = d
        self._f = nn.Sequential(
            nn.Linear(d, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, d),
        )
        self.reset_parameter()

    def reset_parameter(self):
        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)

            else:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain("relu"))

    def forward(self, t: Tensor, x: Tensor, **kwargs):
        """Evaluates the vector field at state x (the dynamics is autonomous, so t is unused)
        Input
            t - []    current time point
            x - [N,d] current state
        Returns
            dx - [N,d] time derivative of the state
        """
        return self._f(x)


# define vector-field
field = VectorField(2).to(device)
model = NeuralODE(field, solver="euler").to(device)

# let's compute the integral of our neural net!
x0 = torch.tensor([[1.0, 0.0]]).float().to(device)
t_span = torch.linspace(0.0, 1.0, 100).to(device)

t_eval, trajectory_init = model(x0, t_span)

plot_ode(
    t_eval.detach().cpu().numpy(),
    trajectory_init.detach().cpu().numpy(),
    field.forward,
)

VDP Learning

In the previous example \(f_\theta\) is randomly initialized, so the resulting vector field does not exhibit any interesting behaviour. As mentioned above, the objective is to learn \(f_\theta\) from some observations \(y_n\). Obtaining a state solution \(x_n\) involves solving the ODE defined by \(f_\theta\); likewise, optimizing \(\theta\) involves computing gradients through \(x_n\). As differentiating through the internals of the ODE solver is memory-inefficient, the parameter gradients are instead obtained with the Adjoint State Method. Thus, we can adopt maximum-likelihood estimation (MLE) to train our neural ODE:

\[ \begin{split} \text{argmin}_\theta \mathcal{L}(\theta) & = \frac{1}{2} \sum_{n=0}^N || y_n - x_n||_2^2 \\ \text{s.t.} ~~ x_n & = x_0 + \int_0^{n} f_\theta(x_\tau) \, d\tau \end{split} \]
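For completeness, here is a sketch of the Adjoint State Method as used for Neural ODEs (Chen et al., 2018); sign conventions vary across presentations. The adjoint state \(a(t) = \frac{\partial \mathcal{L}}{\partial x(t)}\) obeys a second ODE that is integrated backwards in time, and the parameter gradients are recovered along the way:

\[ \begin{align} \frac{d a(t)}{d t} & = - a(t)^\top \frac{\partial f_\theta(x(t), t)}{\partial x(t)} \\ \frac{\partial \mathcal{L}}{\partial \theta} & = - \int_{t_N}^{t_0} a(t)^\top \frac{\partial f_\theta(x(t), t)}{\partial \theta} \, dt \end{align} \]

Since no intermediate solver states need to be stored, the memory cost is constant in the number of solver steps.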

To train our \(f_\theta\), some observations of the system are needed. Below, noisy observations are generated from the original VDP oscillator.

import torch.utils.data as data

# generate noise examples
y_n = vdps_sol_euler + torch.randn_like(vdps_sol_euler) * 0.05

observations_flatten = y_n.permute(1, 0, 2).reshape(-1, 2)
t_eval = t_vdps_eval.repeat(3)

t_steps = 50
dataset_size = 10


# function that extracts a random sub-sequence from the batch of trajectories
def get_example(t_eval_, observations_, t_steps):
    T = t_eval_.size(0) // 3  # length of a single trajectory
    # pick one of the three trajectories uniformly at random
    example_id = torch.multinomial(torch.ones(3) / 3, 1).item()

    t_start = example_id * T
    t_end = (example_id + 1) * T
    t0 = torch.randint(
        t_start, (t_end - t_steps), (1,)
    ).item()  # pick the initial value

    (
        t_span_,
        x_start_,
        x_targets_,
    ) = (
        t_eval_[t0 : t0 + t_steps],
        observations_[t0],
        observations_[t0 : t0 + t_steps],
    )  # pick subsequences

    return (
        t_span_.detach().cpu().numpy(),
        x_start_.unsqueeze(0).detach().cpu().numpy(),
        x_targets_.unsqueeze(0).detach().cpu().numpy(),
    )


t_spans = []
x_starts = []
x_targets = []

for i in range(dataset_size):
    t_span, x_start, x_target = get_example(t_eval, observations_flatten, t_steps)
    t_spans.append(t_span)
    x_starts.append(x_start)
    x_targets.append(x_target)


t_spans = torch.from_numpy(np.stack(t_spans, 0)).float()
x_starts = torch.from_numpy(np.concatenate(x_starts, 0)).float()
x_targets = torch.from_numpy(np.concatenate(x_targets, 0)).float()
t_spans.size(), x_starts.size(), x_targets.size()

train_dataset = data.TensorDataset(x_starts, x_targets, t_spans)

# plot noise examples
plot_ode(
    t_vdps_eval.detach().cpu().numpy(),
    y_n.detach().cpu().numpy(),
    vdp,
)

Having obtained a sample of trajectories, we can train a model on it. Namely, the objective is to adjust the parameters of the vector field such that, given a starting point, we can simulate the dynamics of the system. Note that during training we cannot leverage GPU parallelization through batches, as each example has its own time span and thus requires its own ODE solve (hence batch_size=1 below). However, it is possible to save memory by training on sections of the collected trajectories rather than full ones.

from PIL import Image
import torch
import torch.nn as nn
import torch.utils.data as data
import pytorch_lightning as pl

epochs = 15  # number of optimization iterations
lr = 5e-4
max_steps_per_epoch = 300
trainloader = data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
)
model.vf.vf.reset_parameter()


class Learner(pl.LightningModule):

    def __init__(
        self,
        model: nn.Module,
        lr: float,
    ):
        super().__init__()
        self.model, self.lr = model, lr
        self.loss = nn.MSELoss()

    def forward(self, x, t_span, **kwargs):
        return self.model(x, t_span, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y, t_span = batch
        t_span = t_span[0]  # [1, T] -> [T]
        t_eval, y_hat = self.model(x, t_span)
        y = y[0].flatten()
        y_hat = y_hat[:, 0].flatten()
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)


learner = Learner(
    model,
    lr=lr,
)

trainer = pl.Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
    max_epochs=epochs,
    limit_train_batches=max_steps_per_epoch,
)
# trainer.fit(learner, trainloader)

Below is an animation demonstrating how the training process refines the underlying vector field to produce increasingly accurate predictions of the system dynamics.

Classification Problem

Until now we have focused on problems that naturally fit the Neural ODE framework, namely regression over trajectories of physical systems where the functional form of the dynamics is learned by a neural network. However, Neural ODEs are not strictly limited to regression problems. By choosing arbitrary integration times \(t_0\) and \(t_1\) as the start and end of the dynamics, \[ x(t_1) = x(t_0) + \int_{t_0}^{t_1} f_\theta(x(\tau)) \, d\tau, \] the system transforms the input data \(x(t_0)\) into a representation \(x(t_1)\) that is useful for solving a classification problem. A Neural ODE's vector field is limited to data of a fixed dimensionality, in other words \(x(t_0), x(t_1) \in \mathbb{R}^n\). However, as Neural ODEs are continuous and differentiable, they can be trained jointly with a linear classifier to solve the final task.

For example, given the dataset in Figure 1, it is possible to define the following classifier:

from torch import nn, Tensor
from torchdyn.core import NeuralODE


class VectorField(nn.Module):
    def __init__(
        self,
        input_dim: int,
    ):
        """d - ODE dimensionality"""
        super().__init__()
        self.input_dim = input_dim
        self._f = nn.Sequential(
            nn.Linear(self.input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, self.input_dim),
        )

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, val=0.0)

            else:
                nn.init.xavier_normal_(param, gain=nn.init.calculate_gain("relu"))

    def forward(self, t: Tensor, x: Tensor, **kwargs):
        return self._f(x)


class ClassifierODE(nn.Module):
    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        solver="dopri5",
        rtol=1e-5,
        atol=1e-5,
        sensitivity: str = "adjoint",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes

        self.vector_field = VectorField(input_dim=self.input_dim)
        # combine neural ODE
        self.neural_ode = NeuralODE(
            vector_field=self.vector_field,
            solver=solver,
            atol=atol,
            rtol=rtol,
            sensitivity=sensitivity,
            atol_adjoint=atol,
            rtol_adjoint=rtol,
        )
        # and linear classifier
        self.final = nn.Linear(
            self.input_dim,
            self.num_classes,
            bias=False,
        )

    def forward(self, x: Tensor, t_span: Tensor, **kwargs):
        t_eval, trajectory = self.neural_ode(x, t_span, **kwargs)
        return self.final(trajectory[-1])

    def reset_parameters(self):
        self.vector_field.reset_parameters()
        nn.init.xavier_normal_(self.final.weight)


input_dim = 2
num_classes = 1

model = ClassifierODE(
    input_dim=input_dim,
    num_classes=num_classes,
)
model = model.to(device)
Figure 1: Illustration of the HalfMoon dataset used for training.
from half_moon import HalfMoonDataset
from torch.utils.data import DataLoader

dataset = HalfMoonDataset(1000)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size]
)

epochs = 15  # number of optimization iterations
lr = 1e-3
model.reset_parameters()
# we can use a batch size > 1 here since all examples share the same time span
batch_size = 64

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
valid_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)


class HalfMoonLearner(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        lr: float,
    ):
        super().__init__()
        self.model, self.lr = model, lr
        self.loss = nn.BCEWithLogitsLoss()
        self.epoch_images = []
        self.t_span = torch.linspace(0.0, 1.0, 50)
        self.y_n = y_n

    def forward(self, x, t_span, **kwargs):
        return self.model(x, t_span, **kwargs)

    def training_step(self, batch, batch_idx):
        x0 = batch["data"]
        y = batch["label"].unsqueeze(-1)
        y_hat = self.model(x0, self.t_span)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)


half_moon_learner = HalfMoonLearner(
    model,
    lr=lr,
)

trainer = pl.Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
    max_epochs=epochs,
)

# trainer.fit(half_moon_learner, train_dataloader)
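After training, a minimal sketch for measuring test accuracy (assuming, as in training_step above, that each batch is a dict with "data" and "label" keys):

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in test_dataloader:
        # the forward pass returns one logit per example
        logits = half_moon_learner(batch["data"].to(device), half_moon_learner.t_span)
        preds = (torch.sigmoid(logits) > 0.5).float().squeeze(-1)
        correct += (preds == batch["label"].to(device)).sum().item()
        total += batch["label"].size(0)
print(f"test accuracy: {correct / total:.3f}")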

Finally, we can inspect the learned trajectories for classification as we did for the regression problem.
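A minimal sketch (assuming HalfMoonDataset supports integer indexing and returns a dict with a "data" tensor, consistent with the batches used above):

# a batch of inputs to integrate through the learned dynamics
x0 = torch.stack([dataset[i]["data"] for i in range(64)]).to(device)
t_span = torch.linspace(0.0, 1.0, 50).to(device)

with torch.no_grad():
    # integrate the learned vector field only, skipping the final linear layer
    t_eval, trajectory = model.neural_ode(x0, t_span)  # [T], [T, N, d]

plot_ode(
    t_eval.detach().cpu().numpy(),
    trajectory.detach().cpu().numpy(),
    model.vector_field.forward,
)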

Figure 2: Learned trajectories for a classification problem. Note that the graph represents the evolution through time of each dimension of the input \(x\). I.e., it is the evolution from \(x(t_0)\) to \(x(t_1)\), but this time it is learned purely from the class label and not from a sampled trajectory.

Resources