neural_dynamics.training

def train_loop(
    trajectories: torch.Tensor,
    time_grid: torch.Tensor,
    hyperparameters: neural_dynamics.core.hyperparameters.Hyperparameters,
    device: torch.device,
    enable_adversarial: bool = True,
) -> Tuple[neural_dynamics.models.NeuralODE, neural_dynamics.models.NeuralSDE, list[float], list[float]]:

Train a Neural SDE model by first training a Neural ODE, then using it to initialise the SDE.

Arguments:
  • neural_sde_placeholder: Placeholder argument kept for compatibility; it is ignored and does not appear in the signature above.
  • trajectories: Training trajectories tensor of shape [batch, time, state]
  • time_grid: Time grid tensor
  • hyperparameters: Hyperparameters for training
  • device: Device to train on
  • enable_adversarial: Whether to enable adversarial training for the SDE
Returns:
  • Trained NeuralODE
  • Trained NeuralSDE
  • Training losses (list[float])
  • Validation losses (list[float])
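The two-stage idea (fit an ODE first, then reuse its drift to initialise the SDE) can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the neural_dynamics implementation: DriftNet, ToySDE, euler_rollout and two_stage_init are hypothetical names, the ODE is integrated with a fixed-step Euler scheme, the SDE is assumed to have a learnable diagonal diffusion, and validation losses, the SDE training stage and adversarial training are omitted.

```python
import torch
from torch import nn


class DriftNet(nn.Module):
    """Hypothetical drift network f(x), shared by the ODE and the SDE sketch."""

    def __init__(self, state_dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(), nn.Linear(hidden, state_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ToySDE(nn.Module):
    """Hypothetical SDE: drift f(x) plus a learnable diagonal diffusion (log scale)."""

    def __init__(self, state_dim: int, hidden: int = 32):
        super().__init__()
        self.drift = DriftNet(state_dim, hidden)
        # Assumption: start with a small noise level so the SDE initially behaves like the ODE.
        self.log_sigma = nn.Parameter(torch.full((state_dim,), -2.0))


def euler_rollout(drift: nn.Module, x0: torch.Tensor, time_grid: torch.Tensor) -> torch.Tensor:
    """Integrate dx/dt = drift(x) with explicit Euler; returns [batch, time, state]."""
    xs = [x0]
    for i in range(len(time_grid) - 1):
        dt = time_grid[i + 1] - time_grid[i]
        xs.append(xs[-1] + dt * drift(xs[-1]))
    return torch.stack(xs, dim=1)


def two_stage_init(trajectories: torch.Tensor, time_grid: torch.Tensor,
                   epochs: int = 50, lr: float = 1e-2):
    """Stage 1: fit an ODE drift by trajectory MSE. Stage 2: copy it into an SDE."""
    state_dim = trajectories.shape[-1]
    ode_drift = DriftNet(state_dim)
    opt = torch.optim.Adam(ode_drift.parameters(), lr=lr)
    losses = []
    for _ in range(epochs):
        opt.zero_grad()
        pred = euler_rollout(ode_drift, trajectories[:, 0], time_grid)
        loss = nn.functional.mse_loss(pred, trajectories)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    # Initialise the SDE drift from the trained ODE drift (values are copied, not shared).
    sde = ToySDE(state_dim)
    sde.drift.load_state_dict(ode_drift.state_dict())
    return ode_drift, sde, losses


if __name__ == "__main__":
    torch.manual_seed(0)
    time_grid = torch.linspace(0.0, 1.0, 11)
    x0 = torch.randn(8, 2)
    # Synthetic exponential-decay trajectories of shape [batch, time, state].
    trajectories = x0[:, None, :] * torch.exp(-time_grid)[None, :, None]
    ode, sde, losses = two_stage_init(trajectories, time_grid)
    print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Copying the drift weights with load_state_dict means the SDE starts from the ODE's deterministic dynamics, so its own training only has to learn the noise and refine the drift; because the values are copied, further SDE updates leave the returned ODE unchanged.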