neural_dynamics.training.train_loop
def train_loop(trajectories: torch.Tensor, time_grid: torch.Tensor, hyperparameters: neural_dynamics.core.hyperparameters.Hyperparameters, device: torch.device, enable_adversarial: bool = True) -> Tuple[neural_dynamics.models.NeuralODE, neural_dynamics.models.NeuralSDE, list[float], list[float]]:
Train a Neural SDE model by first training a Neural ODE, then using it to initialise the SDE.
Arguments:
- neural_sde_placeholder: Placeholder argument kept for compatibility; it is ignored and does not appear in the signature above.
- trajectories: Training trajectories tensor of shape [batch, time, state]
- time_grid: Time grid tensor
- hyperparameters: Hyperparameters for training
- device: Device to train on
- enable_adversarial: Whether to enable adversarial training for the SDE
Returns:
A tuple of (trained NeuralODE, trained NeuralSDE, training losses, validation losses). Both loss entries are lists of floats.
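
The following is a minimal sketch of how the call and its four-element return value might be handled. The helper names `train_and_report` and `summarise_losses` are hypothetical and not part of `neural_dynamics`. The sketch takes `train_loop` as a parameter so that it does not depend on the package being importable. It also assumes, without confirmation from the documentation above, that each loss list holds one value per epoch.

```python
from typing import Any, Callable, Dict, List, Tuple


def summarise_losses(train_losses: List[float], val_losses: List[float]) -> Dict[str, Any]:
    """Report the final losses and the position of the lowest validation loss.

    Assumption (not stated in the docs): one entry per epoch in each list.
    """
    if not val_losses:
        raise ValueError("val_losses is empty")
    # Index of the smallest validation loss; ties resolve to the earliest index.
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return {
        "final_train_loss": train_losses[-1] if train_losses else None,
        "final_val_loss": val_losses[-1],
        "best_val_index": best_index,
        "best_val_loss": val_losses[best_index],
    }


def train_and_report(
    train_loop: Callable[..., Tuple[Any, Any, List[float], List[float]]],
    trajectories: Any,
    time_grid: Any,
    hyperparameters: Any,
    device: Any,
    enable_adversarial: bool = True,
) -> Tuple[Any, Any, Dict[str, Any]]:
    """Call train_loop with the documented arguments and summarise its losses."""
    # Unpack in the documented order: ODE model, SDE model, then the two loss lists.
    neural_ode, neural_sde, train_losses, val_losses = train_loop(
        trajectories,
        time_grid,
        hyperparameters,
        device,
        enable_adversarial=enable_adversarial,
    )
    return neural_ode, neural_sde, summarise_losses(train_losses, val_losses)
```

In real use, `train_loop` would be the function documented here, `trajectories` a tensor of shape [batch, time, state], and `device` a `torch.device`. Passing `enable_adversarial=False` would skip adversarial training for the SDE, per the argument description above.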