Neural ODEs

Author

Sandro Cavallari

Published

April 10, 2024

Ordinary Differential Equations

Ordinary Differential Equations, or ODEs, are equations with a single independent variable (usually called time $t$) and one or more derivatives of functions defined in terms of that independent variable. Formally,

$$
f(x(t), t) = \dot{x}(t) = \frac{\partial x(t)}{\partial t} = \lim_{\Delta t \to 0} \frac{x(t + \Delta t) - x(t)}{\Delta t}
$$

where:

  • $t$ represents time, or any other independent variable used as the domain of the derivative operator;
  • $x(t) \in \mathbb{R}^d$ is the dependent variable defining the system’s state;
  • $\dot{x}(t) \in \mathbb{R}^d$ is the first-order derivative of $x(t)$;
  • $f: \mathbb{R}^d \times \mathbb{R}_+ \to \mathbb{R}^d$ is the vector field describing the system’s evolution over time.
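As a quick numerical check of the limit definition above, the forward difference quotient $(x(t+\Delta t) - x(t)) / \Delta t$ approaches the true derivative as $\Delta t$ shrinks. The sketch below uses $x(t) = \sin t$ purely as an illustrative trajectory (it is not part of the original post); its exact derivative is $\cos t$.

```python
import math


def forward_difference(x, t, dt):
    """Approximate dx/dt at time t with the quotient (x(t + dt) - x(t)) / dt."""
    return (x(t + dt) - x(t)) / dt


# Illustrative trajectory x(t) = sin(t), whose exact derivative is cos(t).
t = 1.0
exact = math.cos(t)
dts = (1e-1, 1e-2, 1e-3)
errors = [abs(forward_difference(math.sin, t, dt) - exact) for dt in dts]

for dt, err in zip(dts, errors):
    print(f"dt={dt:.0e}  |error|={err:.2e}")
```

The error shrinks roughly in proportion to $\Delta t$, which is what the limit expresses.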

As f defines the evolution of a complex system over every infinitesimal interval of time Δt, we can formally define an ODE problem as:

$$
x(t + \Delta t) = x(t) + \Delta t \, f(x(t), t) \quad \text{s.t.} \quad \Delta t \to 0.
$$
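Replacing the limit $\Delta t \to 0$ with a small but finite step turns this update rule into the explicit Euler method. Below is a minimal sketch on the scalar ODE $\dot{x} = -x$, an illustrative choice made here because its exact solution $x_0 e^{-t}$ is known, so the error can be measured directly.

```python
import math


def euler(f, x0, t0, t1, n_steps):
    """Integrate dx/dt = f(x, t) from t0 to t1 using n_steps explicit Euler steps."""
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = x + dt * f(x, t)  # x(t + dt) ≈ x(t) + dt * f(x(t), t)
        t = t + dt
    return x


def decay(x, t):
    """Illustrative vector field dx/dt = -x, with exact solution x(t) = x0 * exp(-t)."""
    return -x


exact = math.exp(-1.0)  # x(1) for x0 = 1
errors = {n: abs(euler(decay, 1.0, 0.0, 1.0, n) - exact) for n in (10, 100, 1000)}
print(errors)
```

Multiplying the number of steps by ten divides the error by roughly ten, as expected from a first-order method.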

Computing a Solution

In general, solving an ODE involves computing the antiderivative of $f$, in other words the integral of $f$. Since an antiderivative is only defined up to an arbitrary constant, usually denoted $C$, an initial condition $x_0$ must be specified to guarantee that the solution of the ODE is unique:

$$
x_t = x_0 + \int_0^t f(x_\tau) \, d\tau
$$

Note that, given the initial state $x_0$ and a set of time points $\{t_0, \dots, t_N\}$, the objective is to obtain the state solution $x_{0:N} = \{x_0, \dots, x_N\}$. Unfortunately, the above integral can be solved analytically only for a limited class of vector fields. Therefore, numerical solvers are used in practice.
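Numerical solvers differ mainly in how many evaluations of $f$ they use per step. The post itself uses torchdyn's `"euler"` solver here and `"dopri5"` later on; the plain-Python sketch below is not torchdyn's implementation, only an illustration of why higher-order methods are preferred. It compares explicit Euler with the classic fourth-order Runge-Kutta (RK4) method on the same illustrative ODE $\dot{x} = -x$ over the grid $t_0 = 0, \dots, t_N = 1$.

```python
import math


def euler_step(f, x, t, dt):
    # first-order: a single slope evaluation per step
    return x + dt * f(x, t)


def rk4_step(f, x, t, dt):
    # classic fourth-order Runge-Kutta: weighted average of four slope evaluations
    k1 = f(x, t)
    k2 = f(x + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f(x + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = f(x + dt * k3, t + dt)
    return x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def integrate(step, f, x0, t_points):
    """Return the numerical state solution x_{0:N} at every time point t_0, ..., t_N."""
    xs = [x0]
    for t_a, t_b in zip(t_points[:-1], t_points[1:]):
        xs.append(step(f, xs[-1], t_a, t_b - t_a))
    return xs


def decay(x, t):
    return -x  # illustrative ODE with exact solution x(t) = exp(-t) for x0 = 1


t_points = [0.1 * i for i in range(11)]  # t_0 = 0, ..., t_N = 1
exact = [math.exp(-t) for t in t_points]
max_err = {
    name: max(abs(a - b) for a, b in zip(integrate(step, decay, 1.0, t_points), exact))
    for name, step in (("euler", euler_step), ("rk4", rk4_step))
}
print(max_err)
```

With the same step size, RK4 costs four evaluations of $f$ per step but is several orders of magnitude more accurate here.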

import numpy as np
import torch
import torch.nn as nn
from torchdyn.numerics import odeint

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cpu")


class VDPoscillator(nn.Module):
    def __init__(self, mu: float) -> None:
        super().__init__()
        self.mu = mu

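    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # t is unused (autonomous system) but kept in the (t, x) signature the solver calls
        var_0 = x[..., 1]  # dx_0/dt = x_1
        var_1 = self.mu * (1 - x[..., 0] ** 2) * x[..., 1] - x[..., 0]  # dx_1/dt
        return torch.stack(
            (var_0, var_1),
            dim=-1,
        )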
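For reference, `VDPoscillator.forward` above implements the Van der Pol oscillator written as a first-order system in the state $x = (x_0, x_1)$, where $\mu$ (`self.mu`) controls the strength of the nonlinear damping:

$$
\dot{x}_0 = x_1, \qquad \dot{x}_1 = \mu \, (1 - x_0^2) \, x_1 - x_0
$$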


from plot_vector_field import plot_vdp_animation

vdp = VDPoscillator(1.0).to(device)

# initial value, of shape [N, d]
x0 = torch.tensor([[1.0, 0.0]]).float().to(device)

# integration time points, of shape [T]
t_span = torch.linspace(0.0, 15.0, 500).to(device)

with torch.no_grad():
    t_eval, vdp_sol_euler = odeint(
        vdp, x0, t_span=t_span, solver="euler"
    )  # [T], [T, N, d]

t_eval.size(), vdp_sol_euler.size()

# anim = plot_vdp_animation(
#     t=t_eval.detach().cpu().numpy(),
#     X=vdp_sol_euler.detach().cpu().numpy(),
#     ode_rhs=vdp,
# )
(torch.Size([500]), torch.Size([500, 1, 2]))

      As mentioned above, the solution depends on the initial state of the system. The following example shows how the system behaves when started from different initial points.

      from plot_vector_field import plot_ode
      
      x0 = (
          torch.tensor(
              [
                  [1.0, 0.0],
                  [-2.0, -3.0],
                  [-2.0, 3.0],
              ]
          )
          .float()
          .to(device)
      )
      
      # integration time points, of shape [T]
      t_span = torch.linspace(0.0, 15.0, 500).to(device)
      
      # forward integration
      with torch.no_grad():
          t_vdps_eval, vdps_sol_euler = odeint(vdp, x0, t_span=t_span, solver="euler")
      
      plot_ode(
          t=t_vdps_eval.detach().cpu().numpy(),
          X=vdps_sol_euler.detach().cpu().numpy(),
          ode_rhs=vdp,
      )

      Neural ODEs

      Neural ODEs are a family of ODEs whose vector field $f(x_t, t)$ is defined by a neural network. As such, $f(x_t, t)$ is both differentiable and learnable. Thus, given a set of observations $y_{0:N}$ from an unknown dynamical system, we can use them to learn a model of the system’s dynamics.
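To make "differentiable and learnable" concrete, here is a minimal sketch in plain PyTorch, independent of torchdyn. The small MLP vector field and the hand-written Euler loop (`odeint_euler`, a hypothetical helper) are illustrative choices; the point is that autograd back-propagates a loss on the final state through every solver step to the network parameters $\theta$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# f_theta(x): a small MLP vector field (architecture chosen only for illustration)
f_theta = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))


def odeint_euler(f, x0, t_span):
    """Explicit Euler solve of dx/dt = f(x); returns states of shape [T, N, d]."""
    xs = [x0]
    for t_a, t_b in zip(t_span[:-1], t_span[1:]):
        xs.append(xs[-1] + (t_b - t_a) * f(xs[-1]))
    return torch.stack(xs)


x0 = torch.tensor([[1.0, 0.0]])  # [N, d]
t_span = torch.linspace(0.0, 1.0, 20)  # [T]
traj = odeint_euler(f_theta, x0, t_span)  # [T, N, d]

# Any differentiable loss on the trajectory back-propagates through all solver steps to theta.
loss = traj[-1].pow(2).sum()
loss.backward()
grad_norm = sum(p.grad.norm().item() for p in f_theta.parameters())
print(traj.shape, grad_norm)
```

This "differentiate through the solver" approach keeps every intermediate state in memory; the adjoint method mentioned in the next section is an alternative way to obtain the same gradients.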

      Problem Formulation

      Consider a dataset of noisy observations $y_n$, where each observation is a perturbation of an unknown state $x_n$ generated by unknown underlying dynamics $f_{\text{true}}$:

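      $$
      \begin{aligned}
      y_n &= x_n + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2) && (y_n \text{ is a noisy variable}) \\
      x_n &= x_0 + \int_0^{t_n} f_{\text{true}}(x_\tau) \, d\tau
      \end{aligned}
      $$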

      The objective is to learn a neural network fθ that matches the unknown dynamics:

      $$
      f_\theta \approx f_{\text{true}}.
      $$

      from torch import nn, Tensor
      from torchdyn.core import NeuralODE
      
      
      class VectorField(nn.Module):
          def __init__(self, d: int):
              """d - ODE dimensionality"""
              super().__init__()
              self.d = d
              self._f = nn.Sequential(
                  nn.Linear(d, 20),
                  nn.ReLU(),
                  nn.Linear(20, 20),
                  nn.ReLU(),
                  nn.Linear(20, d),
              )
              self.reset_parameter()
      
          def reset_parameter(self):
              for name, param in self.named_parameters():
                  if "bias" in name:
                      nn.init.constant_(param, 0.0)
      
                  else:
                      nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain("relu"))
      
          def forward(self, t: Tensor, x: Tensor, **kwargs):
              """Evaluate the learned vector field f_theta(x).
              Input
                  t - current time (unused: the learned dynamics are autonomous)
                  x - [N,d] current states
              Returns
                  dx/dt - [N,d] time derivatives of the states
              """
              return self._f(x)
      
      
      # define vector-field
      field = VectorField(2).to(device)
      model = NeuralODE(field, solver="euler").to(device)
      
      # let's compute the integral of our neural net!
      x0 = torch.tensor([[1.0, 0.0]]).float().to(device)
      t_span = torch.linspace(0.0, 1.0, 100).to(device)
      
      t_eval, trajectory_init = model(x0, t_span)
      
      plot_ode(
          t_eval.detach().cpu().numpy(),
          trajectory_init.detach().cpu().numpy(),
          field.forward,
      )

      VDP Learning

      In the previous example $f_\theta$ is randomly initialized, so the resulting vector field does not exhibit any interesting behaviour. As mentioned above, the objective is to learn $f_\theta$ from the observations $y_n$. Obtaining a state solution $x_n$ involves solving the ODE defined by $f_\theta$; similarly, optimizing $\theta$ requires back-propagating the gradients of the loss through $x_n$, i.e., through the ODE solver. Since differentiating through every internal step of the solver is memory-intensive, the parameter gradients can instead be obtained with the adjoint state method. Under the Gaussian noise assumption above, maximum-likelihood estimation (MLE) of $\theta$ reduces to minimizing the squared error between observations and predicted states:

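      $$
      \theta^\ast = \arg\min_\theta \mathcal{L}(\theta), \qquad \mathcal{L}(\theta) = \frac{1}{2} \sum_{n=0}^{N} \lVert y_n - x_n \rVert_2^2 \quad \text{s.t.} \quad x_n = x_0 + \int_0^{t_n} f_\theta(x_\tau) \, d\tau
      $$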
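The adjoint state method can be illustrated on a scalar toy problem where the gradient is known in closed form. The NumPy sketch below is an illustrative implementation, not torchdyn's: for $\dot{x} = \theta x$ and $\mathcal{L} = \frac{1}{2}(x(T) - y)^2$, it integrates the state $x$, the adjoint $a(t) = \partial \mathcal{L} / \partial x(t)$ and a gradient accumulator $g$ backwards from $T$ to $0$, so that $d\mathcal{L}/d\theta = \int_0^T a(t)\, \partial f / \partial \theta \, dt$ is obtained without storing the forward trajectory. The values of `theta`, `x0`, `y` and `T` are arbitrary.

```python
import math

import numpy as np


def rk4_solve(f, s0, t0, t1, n_steps):
    """Integrate ds/dt = f(s, t) from t0 to t1 (t1 < t0 is allowed) with classic RK4."""
    s, t = np.asarray(s0, dtype=float), t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        k1 = f(s, t)
        k2 = f(s + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(s + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(s + dt * k3, t + dt)
        s = s + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t = t + dt
    return s


theta, x0, y, T = 0.5, 1.0, 2.0, 1.0

# 1) forward pass: solve dx/dt = theta * x and evaluate L = 0.5 * (x(T) - y)^2
xT = rk4_solve(lambda s, t: theta * s, [x0], 0.0, T, 100)[0]
loss = 0.5 * (xT - y) ** 2


# 2) backward pass: integrate the augmented state [x, a, g] from T back to 0
def augmented_dynamics(s, t):
    x, a, _ = s
    return np.array([
        theta * x,  # state dynamics, re-solved backwards instead of being stored
        -a * theta,  # adjoint: da/dt = -a * df/dx
        -a * x,  # accumulator: dg/dt = -a * df/dtheta, so g(0) = integral of a * x
    ])


s_T = [xT, xT - y, 0.0]  # a(T) = dL/dx(T), g(T) = 0
grad_adjoint = rk4_solve(augmented_dynamics, s_T, T, 0.0, 100)[2]

# closed-form gradient of L = 0.5 * (x0 * exp(theta * T) - y)^2 for comparison
grad_exact = (x0 * math.exp(theta * T) - y) * x0 * T * math.exp(theta * T)
print(loss, grad_adjoint, grad_exact)
```

In a neural ODE, the same backward solve is run with $\partial f_\theta / \partial x$ and $\partial f_\theta / \partial \theta$ computed by vector-Jacobian products instead of by hand.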

      To train $f_\theta$, we need observations of the system. Thus, we next generate noisy observations from the original VDP oscillator.

      import torch.utils.data as data
      
      # generate noisy observations y_n = x_n + eps, with eps ~ N(0, 0.05^2)
      y_n = vdps_sol_euler + torch.randn_like(vdps_sol_euler) * 0.05
      
      observations_flatten = y_n.permute(1, 0, 2).reshape(-1, 2)
      t_eval = t_vdps_eval.repeat(3)
      
      t_steps = 50
      dataset_size = 10
      
      
      # function that extracts a random sub-trajectory of length t_steps from one of the 3 trajectories
      def get_example(t_eval_, observations_, t_steps):
          T = t_eval_.size(0) // 3
          example_id = torch.multinomial(torch.ones(3) / 3, 1).item()
      
          t_start = example_id * T
          t_end = (example_id + 1) * T
          t0 = torch.randint(
              t_start, (t_end - t_steps), (1,)
          ).item()  # pick the initial value
      
          (
              t_span_,
              x_start_,
              x_targets_,
          ) = (
              t_eval_[t0 : t0 + t_steps],
              observations_[t0],
              observations_[t0 : t0 + t_steps],
          )  # pick subsequences
      
          return (
              t_span_.detach().cpu().numpy(),
              x_start_.unsqueeze(0).detach().cpu().numpy(),
              x_targets_.unsqueeze(0).detach().cpu().numpy(),
          )
      
      
      t_spans = []
      x_starts = []
      x_targets = []
      
      for i in range(dataset_size):
          t_span, x_start, x_target = get_example(t_eval, observations_flatten, t_steps)
          t_spans.append(t_span)
          x_starts.append(x_start)
          x_targets.append(x_target)
      
      
      # stack the lists of numpy arrays into single tensors
      t_spans = torch.from_numpy(np.stack(t_spans)).float()  # [dataset_size, t_steps]
      x_starts = torch.from_numpy(np.concatenate(x_starts, 0)).float()  # [dataset_size, 2]
      x_targets = torch.from_numpy(np.concatenate(x_targets, 0)).float()  # [dataset_size, t_steps, 2]
      t_spans.size(), x_starts.size(), x_targets.size()
      
      train_dataset = data.TensorDataset(x_starts, x_targets, t_spans)
      
      # plot noisy examples
      plot_ode(
          t_vdps_eval.detach().cpu().numpy(),
          y_n.detach().cpu().numpy(),
          vdp,
      )

      Having obtained a sample of trajectories, we can train a model on it. Namely, the objective is to adjust the parameters of the vector field so that, given a starting point, the model reproduces the dynamics of the system. Note that in this setup we cannot batch training examples: each sub-trajectory has its own time span, while a single solver call integrates all states over one shared `t_span`, so we use `batch_size=1`. However, training on short sections of the collected trajectories, rather than on full trajectories, keeps the memory cost of each step low.
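The batching restriction comes from the time grid, not from the ODE itself: one solver call advances every row of an `[N, d]` state tensor over a single shared `t_span`. The sketch below uses plain PyTorch with a hypothetical hand-written Euler loop standing in for the solver, and shows that solving $N$ initial states in one batched call matches solving them one by one. This is why a shared time span allows mini-batches, as in the classification example later in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
f_theta = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))  # illustrative vector field


def odeint_euler(f, x0, t_span):
    """Hypothetical explicit Euler solver: all rows of x0 share the same time grid."""
    xs = [x0]
    for t_a, t_b in zip(t_span[:-1], t_span[1:]):
        xs.append(xs[-1] + (t_b - t_a) * f(xs[-1]))
    return torch.stack(xs)  # [T, N, d]


t_span = torch.linspace(0.0, 1.0, 50)  # one time grid shared by every example
x0_batch = torch.randn(8, 2)  # [N, d]: 8 initial states solved in a single call

with torch.no_grad():
    batched = odeint_euler(f_theta, x0_batch, t_span)
    # Solving each initial state separately gives the same result, but needs N solver calls.
    separate = torch.cat(
        [odeint_euler(f_theta, x0_batch[i : i + 1], t_span) for i in range(8)], dim=1
    )
print(batched.shape)
```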

      import torch
      import torch.nn as nn
      import torch.utils.data as data
      import pytorch_lightning as pl
      
      epochs = 15  # number of training epochs
      lr = 5e-4
      max_steps_per_epoch = 300
      # batch_size=1: each example comes with its own time span
      trainloader = data.DataLoader(
          train_dataset,
          batch_size=1,
          shuffle=True,
      )
      # re-initialize the vector field wrapped by `model` before training
      field.reset_parameter()
      
      
      class Learner(pl.LightningModule):
      
          def __init__(
              self,
              model: nn.Module,
              lr: float,
          ):
              super().__init__()
              self.model, self.lr = model, lr
              self.loss = nn.MSELoss()
      
          def forward(self, x, t_span, **kwargs):
              return self.model(x, t_span, **kwargs)
      
          def training_step(self, batch, batch_idx):
              x, y, t_span = batch
              t_span = t_span[0]  # [1, T] -> [T]
              t_eval, y_hat = self.model(x, t_span)
              y = y[0].flatten()
              y_hat = y_hat[:, 0].flatten()
              loss = self.loss(y_hat, y)
              self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
              return {"loss": loss}
      
          def configure_optimizers(self):
              return torch.optim.AdamW(self.model.parameters(), lr=self.lr)
      
      
      learner = Learner(
          model,
          lr=lr,
      )
      
      trainer = pl.Trainer(
          gradient_clip_val=1.0,
          gradient_clip_algorithm="norm",
          max_epochs=epochs,
          limit_train_batches=max_steps_per_epoch,
      )
      # trainer.fit(learner, trainloader)
      GPU available: False, used: False
      TPU available: False, using: 0 TPU cores
      HPU available: False, using: 0 HPUs

      The following animation shows how the training process progressively refines the underlying vector field, producing increasingly accurate predictions of the system dynamics.


          Classification Problem

          Until now we have focused on problems that naturally fit the Neural ODE setting, namely regression over trajectories of physical systems, where the functional form of the dynamics is learned by a neural network. However, Neural ODEs are not limited to regression problems. By choosing arbitrary integration times $t_0$ and $t_1$, respectively the start and end time of the dynamical system, a Neural ODE can transform the input data $x(t_0)$ into a representation $x(t_1)$ that is useful for solving a classification problem. A Neural ODE’s vector field preserves the dimensionality of the data, in other words $x(t_0), x(t_1) \in \mathbb{R}^n$. However, since Neural ODEs are continuous and differentiable, they can be trained jointly with a linear classifier that maps $x(t_1)$ to the class scores of the final task.
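A minimal, self-contained sketch of this idea in plain PyTorch: a hypothetical `EulerODEBlock` integrates a learned vector field from $t_0 = 0$ to $t_1 = 1$ with fixed-step Euler (a simplification of the adaptive `dopri5` solver used in this post), and a linear layer maps $x(t_1)$ to a single logit. Both are trained jointly with `BCEWithLogitsLoss`. The two synthetic Gaussian blobs are an assumption made here and stand in for the HalfMoon data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class EulerODEBlock(nn.Module):
    """Maps x(t0) to x(t1) with a learned vector field; the output keeps the input dimension."""

    def __init__(self, dim: int, n_steps: int = 10):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, 32), nn.Tanh(), nn.Linear(32, dim))
        self.n_steps = n_steps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dt = 1.0 / self.n_steps  # integrate from t0 = 0 to t1 = 1
        for _ in range(self.n_steps):
            x = x + dt * self.f(x)
        return x


# toy data: two Gaussian blobs in R^2 (hypothetical stand-in for the HalfMoon dataset)
x = torch.cat([torch.randn(100, 2) + 2.0, torch.randn(100, 2) - 2.0])
y = torch.cat([torch.ones(100, 1), torch.zeros(100, 1)])

model = nn.Sequential(EulerODEBlock(2), nn.Linear(2, 1))  # ODE block + linear classifier
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

losses = []
for step in range(200):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    accuracy = ((model(x) > 0).float() == y).float().mean().item()
print(losses[0], losses[-1], accuracy)
```

The ODE block alone cannot change the dimensionality, which is why the final linear layer is needed to produce class scores.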

          For example, given the dataset in Figure 1, it is possible to define the following classifier:

          from torch import nn, Tensor
          from torchdyn.core import NeuralODE
          
          
          class VectorField(nn.Module):
              def __init__(
                  self,
                  input_dim: int,
              ):
                  """input_dim - ODE dimensionality"""
                  super().__init__()
                  self.input_dim = input_dim
                  self._f = nn.Sequential(
                      nn.Linear(self.input_dim, 64),
                      nn.ReLU(),
                      nn.Linear(64, 64),
                      nn.ReLU(),
                      nn.Linear(64, self.input_dim),
                  )
          
                  self.reset_parameters()
          
              def reset_parameters(self):
                  for name, param in self.named_parameters():
                      if "bias" in name:
                          nn.init.constant_(param, val=0.0)
          
                      else:
                          nn.init.xavier_normal_(param, gain=nn.init.calculate_gain("relu"))
          
              def forward(self, t: Tensor, x: Tensor, **kwargs):
                  return self._f(x)
          
          
          class ClassifierODE(nn.Module):
              def __init__(
                  self,
                  input_dim: int,
                  num_classes: int,
                  solver="dopri5",
                  rtol=1e-5,
                  atol=1e-5,
                  sensitivity: str = "adjoint",
              ):
                  super().__init__()
                  self.input_dim = input_dim
                  self.num_classes = num_classes
          
                  self.vector_field = VectorField(input_dim=self.input_dim)
                  # combine neural ODE
                  self.neural_ode = NeuralODE(
                      vector_field=self.vector_field,
                      solver=solver,
                      atol=atol,
                      rtol=rtol,
                      sensitivity=sensitivity,
                      atol_adjoint=atol,
                      rtol_adjoint=rtol,
                  )
                  # and linear classifier
                  self.final = nn.Linear(
                      self.input_dim,
                      self.num_classes,
                      bias=False,
                  )
          
              def forward(self, x: Tensor, t_span: Tensor, **kwargs):
                  t_eval, trajectory = self.neural_ode(x, t_span, **kwargs)
                  return self.final(trajectory[-1])
          
              def reset_parameters(self):
                  self.vector_field.reset_parameters()
                  nn.init.xavier_normal_(self.final.weight)
          
          
          input_dim = 2
          num_classes = 1
          
          model = ClassifierODE(
              input_dim=input_dim,
              num_classes=num_classes,
          )
          model = model.to(device)
          Figure 1: Illustration of the HalfMoon dataset used for training.
          from half_moon import HalfMoonDataset
          from torch.utils.data import DataLoader
          
          dataset = HalfMoonDataset(1000)
          
          train_size = int(0.8 * len(dataset))
          val_size = int(0.1 * len(dataset))
          test_size = len(dataset) - train_size - val_size
          
          train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
              dataset, [train_size, val_size, test_size]
          )
          
          epochs = 15  # number of training epochs
          lr = 1e-3
          model.reset_parameters()
          # we can use mini-batches since all examples share the same time span
          batch_size = 64
          
          train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
          valid_dataloader = DataLoader(val_dataset, batch_size=batch_size)
          test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
          
          
          class HalfMoonLearner(pl.LightningModule):
              def __init__(
                  self,
                  model: nn.Module,
                  lr: float,
              ):
                  super().__init__()
                  self.model, self.lr = model, lr
                  self.loss = nn.BCEWithLogitsLoss()
                  self.epoch_images = []
                  self.t_span = torch.linspace(0.0, 1.0, 50)
          
              def forward(self, x, t_span, **kwargs):
                  return self.model(x, t_span, **kwargs)
          
              def training_step(self, batch, batch_idx):
                  x0 = batch["data"]
                  y = batch["label"].float().unsqueeze(-1)  # BCEWithLogitsLoss expects float targets, [B, 1]
                  y_hat = self.model(x0, self.t_span)
                  loss = self.loss(y_hat, y)
                  self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
                  return {"loss": loss}
          
              def configure_optimizers(self):
                  return torch.optim.AdamW(self.model.parameters(), lr=self.lr)
          
          
          half_moon_learner = HalfMoonLearner(
              model,
              lr=lr,
          )
          
          trainer = pl.Trainer(
              gradient_clip_val=1.0,
              gradient_clip_algorithm="norm",
              max_epochs=epochs,
          )
          
          # trainer.fit(half_moon_learner, train_dataloader)
          GPU available: False, used: False
          TPU available: False, using: 0 TPU cores
          HPU available: False, using: 0 HPUs

          Finally, we can inspect the learned trajectories for classification as we did for the regression problem.

          Figure 2: Learned trajectories for the classification problem. The graph shows the evolution through time of each dimension of the input $x$, i.e., the evolution from $x(t_0)$ to $x(t_1)$; this time, however, it is learned purely from the class labels rather than from a sampled trajectory.

          Resources