neural_dynamics.models
Fully connected neural network with configurable activation functions and Xavier initialisation.
This class provides a standard feedforward architecture that integrates with the parameter system's NetworkArchitecture specification.
The network applies Xavier initialisation (with an approximate gain for the chosen non-linearity) to all layers and inserts the configured activation function between layers, except after the final layer.
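As an illustration, Xavier initialisation of a feedforward stack can be sketched as follows. This is a minimal standalone example, not the library's actual constructor; `build_mlp` and its layer-size list are assumptions for the sketch.

```python
import torch
import torch.nn as nn

def build_mlp(sizes, activation=nn.Tanh):
    """Build a feedforward stack with Xavier-initialised linear layers.

    `sizes` lists layer widths, e.g. [3, 64, 32, 2]. Illustrative sketch
    only; the library's FeedForwardNetwork may differ in detail.
    """
    layers = []
    for i in range(len(sizes) - 1):
        linear = nn.Linear(sizes[i], sizes[i + 1])
        # Xavier/Glorot initialisation, gain approximated for tanh.
        nn.init.xavier_uniform_(linear.weight, gain=nn.init.calculate_gain("tanh"))
        nn.init.zeros_(linear.bias)
        layers.append(linear)
        if i < len(sizes) - 2:  # no activation after the final layer
            layers.append(activation())
    return nn.Sequential(*layers)

net = build_mlp([3, 64, 32, 2])
out = net(torch.randn(10, 3))  # output shape [10, 2]
```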
Arguments:
- architecture: Network architecture specification defining layer dimensions
- activation: Activation function factory. If None, defaults to nn.Tanh. Must return a new activation module instance each time it is called.
- final_activation: Activation function for the final layer. If None, no activation is applied.
- device: Computation device. Defaults to globally configured device.
Example:
>>> from neural_dynamics.core.hyperparameters import NetworkArchitecture
>>> architecture = NetworkArchitecture(input_size=3, hidden_sizes=[64, 32], output_size=2)
>>> network = FeedForwardNetwork(architecture, activation=lambda: nn.ReLU())
>>> output = network(torch.randn(10, 3))  # shape: [10, 2]
Initialise the feedforward network with specified architecture.
Arguments:
- architecture: Network layer configuration
- activation: Activation function
- final_activation: Optional activation for the final layer
- device: Computation device
Forward pass through the network.
Arguments:
- input: Input tensor of shape [batch_size, input_dim]
Returns:
Output tensor of shape [batch_size, output_dim]
Raises:
- ValueError: If the input dimension doesn't match the architecture (in practice, PyTorch's own shape checks may raise first)
Neural network approximating the drift term mu(x, t, u) in the neural SDE.
Initialise the drift network with specified architecture.
Neural ODE model backed by a trained drift network.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Train the drift network using batched trajectory supervision.
Arguments:
- hyperparameters: Training hyperparameters
- trajectories: Training trajectories [batch, time, state]
- time_grid: Time grid for trajectories
- device: Device for training
- validation_split: Fraction of data to use for validation
- early_stopping_patience: Early stopping patience (None uses hyperparameter default)
- wandb_run: Optional W&B run object for logging training progress
Returns:
Trained NeuralODE instance
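The `early_stopping_patience` argument above implies monitoring the validation loss and halting when it stops improving. A minimal sketch of that behaviour, under the assumption of a patience counter with an optional improvement threshold (the `EarlyStopping` class and `min_delta` parameter are illustrative, not the library's API):

```python
class EarlyStopping:
    """Stop training when validation loss fails to improve.

    Minimal sketch of the behaviour implied by `early_stopping_patience`;
    the library's actual logic may differ.
    """

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.84, 0.83, 0.9]):
    if stopper.step(loss):
        break  # stops after three epochs without improvement
```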
Neural network approximating the diffusion term sigma(x, t, u).
Initialise the diffusion network with dimensionality checks.
Neural network that scores trajectories for the WGAN-GP critic.
Initialise the critic with trajectory-specific validation.
Arguments:
- architecture: Network architecture describing the critic network.
- trajectory_length: Number of timesteps in each trajectory segment.
- activation: Factory returning the activation module used between layers.
- device: Device where the network parameters are stored.
Raises:
- ValueError: If the provided configuration is inconsistent.
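Since the critic scores trajectories for a WGAN-GP objective, its training would include the standard gradient penalty on interpolates between real and generated samples. A hedged sketch of that penalty (the `gradient_penalty` function and its signature are assumptions for illustration, not the library's API):

```python
import torch

def gradient_penalty(critic, real, fake):
    """Standard WGAN-GP penalty: (||grad critic(x_hat)||_2 - 1)^2
    at random interpolates between real and fake samples.
    Sketch only; the library's critic training may differ in detail.
    """
    batch = real.shape[0]
    # One interpolation coefficient per sample, broadcast over trailing dims.
    eps = torch.rand(batch, *([1] * (real.dim() - 1)))
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores.sum(), inputs=interp, create_graph=True
    )
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()
```

For trajectory inputs of shape [batch, time, state], the gradient is flattened per sample before taking its norm, so the penalty is computed over the whole trajectory segment at once.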
Forward pass through the network.
Arguments:
- input: Input tensor of shape [batch_size, input_dim]
Returns:
Output tensor of shape [batch_size, output_dim]
Raises:
- ValueError: If the input dimension doesn't match the architecture (in practice, PyTorch's own shape checks may raise first)
Neural SDE model composed of drift, diffusion, and critic networks.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Train a neural SDE starting from a pretrained neural ODE.
Arguments:
- hyperparameters: Training hyperparameters
- neural_ode: Pretrained neural ODE to initialize drift network
- trajectories: Training trajectories [batch, time, state]
- time_grid: Time grid for trajectories
- device: Device for training
- enable_adversarial: Whether to enable adversarial training
- random_seed: Random seed for reproducibility
- wandb_run: Optional W&B run object for logging training progress
Returns:
Trained NeuralSDE instance
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Sample a single stochastic trajectory using custom drift/diffusion functions.
This method allows evaluating the SDE with time-varying external inputs by providing custom drift and diffusion functions that capture the full dynamics.
Arguments:
- initial_state: Initial state [batch_size, state_dim] or [state_dim]
- time_grid: Time points for integration [num_steps]
- drift_function: Custom drift function f(t, y) -> drift
- diffusion_function: Custom diffusion function g(t, y) -> diffusion matrix
- device: Device for computation
Returns:
Trajectory tensor [batch_size, num_steps, state_dim]
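Sampling with custom `f(t, y)` and `g(t, y)` functions as described above amounts to a stochastic integration scheme; a minimal Euler–Maruyama sketch with matching signatures is shown below. The function name and the assumption that `diffusion_function` returns a [batch, state_dim, noise_dim] matrix are illustrative; the library may use a dedicated SDE solver instead.

```python
import torch

def euler_maruyama(initial_state, time_grid, drift_function, diffusion_function):
    """Integrate dX = f(t, X) dt + g(t, X) dW on a fixed time grid.

    Illustrative sketch of the sampling described above, not the
    library's actual solver.
    """
    if initial_state.dim() == 1:  # promote [state_dim] to [1, state_dim]
        initial_state = initial_state.unsqueeze(0)
    state = initial_state
    path = [state]
    for i in range(len(time_grid) - 1):
        t, dt = time_grid[i], time_grid[i + 1] - time_grid[i]
        g = diffusion_function(t, state)                # [batch, state, noise]
        dW = torch.randn(state.shape[0], g.shape[-1]) * dt.sqrt()
        # Euler-Maruyama step: deterministic drift plus diffused noise.
        state = state + drift_function(t, state) * dt + torch.einsum("bsn,bn->bs", g, dW)
        path.append(state)
    return torch.stack(path, dim=1)  # [batch, num_steps, state_dim]
```

With zero diffusion this reduces to the explicit Euler scheme for the underlying ODE, which is a convenient sanity check for the drift term.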