amatorch.optim

Routine to fit AMA filters using gradient descent.

Functions

fit(model, stimuli, labels, epochs[, ...])

Learn AMA filters using gradient descent, with the option to specify a custom loss function.

amatorch.optim.fit(model, stimuli, labels, epochs, loss_fun=None, batch_size=512, learning_rate=0.1, decay_step=1000, decay_rate=1, pairwise=False)

Learn AMA filters using gradient descent, with the option to specify a custom loss function.

Parameters:
  • model (AMA model object) – The AMA model whose filters are learned.

  • stimuli (torch.Tensor) – Stimuli tensor of shape (n_stim, n_channels, n_dim).

  • labels (torch.Tensor) – Category label of each stimulus, as a tensor of shape (n_stim,).

  • epochs (int) – Number of training epochs.

  • loss_fun (callable, optional) – Loss function called as loss_fun(model, stimuli, labels) and returning a scalar loss. If None (default), the negative log posterior at the true category (cross-entropy) is used.

  • batch_size (int, optional) – Number of stimuli per minibatch, by default 512.

  • learning_rate (float, optional) – Initial learning rate, by default 0.1.

  • decay_step (int, optional) – Number of steps between successive learning-rate decays, by default 1000.

  • decay_rate (float, optional) – Multiplicative factor applied to the learning rate at each decay, by default 1 (no decay).

  • pairwise (bool, optional) – Whether to train the filters in pairs, by default False.
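A custom loss_fun receives the model, a batch of stimuli, and the matching labels, and returns a scalar tensor for gradient descent to minimize. As an illustration, the sketch below re-implements the documented default (negative log posterior at the true category). It assumes a hypothetical model.posteriors(stimuli) method returning posterior probabilities of shape (n_stim, n_classes); the AMA model's actual method for computing posteriors is not named here, so substitute it. The function name neg_log_posterior_loss is likewise hypothetical.

```python
import torch


def neg_log_posterior_loss(model, stimuli, labels):
    """Mean negative log posterior at the true category (cross-entropy).

    ASSUMPTION: `model.posteriors(stimuli)` is a hypothetical method returning
    posterior probabilities of shape (n_stim, n_classes); substitute the AMA
    model's real method for computing posteriors.
    """
    posteriors = model.posteriors(stimuli)  # (n_stim, n_classes)
    # Select each stimulus's posterior at its true category.
    true_post = posteriors[torch.arange(labels.shape[0]), labels]
    # A small constant keeps log() finite if a posterior underflows to 0.
    return -torch.log(true_post + 1e-12).mean()
```

Under that assumption it would be passed as fit(model, stimuli, labels, epochs, loss_fun=neg_log_posterior_loss).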
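learning_rate, decay_step, and decay_rate together describe a step-decay schedule. The sketch below assumes StepLR-style behavior, where the rate is multiplied by decay_rate once every decay_step steps. Whether fit counts a "step" as an epoch or as a minibatch is not stated here, and the helper learning_rate_at is hypothetical.

```python
def learning_rate_at(step, learning_rate=0.1, decay_step=1000, decay_rate=1.0):
    """Learning rate after `step` steps under an assumed StepLR-style decay.

    The rate is multiplied by `decay_rate` once every `decay_step` steps.
    """
    return learning_rate * decay_rate ** (step // decay_step)


# With the default decay_rate=1 the learning rate never changes.
print(learning_rate_at(5000))  # 0.1
# Halving every 1000 steps: two decays have happened by step 2500.
print(learning_rate_at(2500, decay_rate=0.5))  # 0.025
```

This is why the default decay_rate of 1 amounts to a constant learning rate regardless of decay_step.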

Returns:

  • torch.Tensor – Loss at each epoch, as a tensor of shape (epochs,).

  • torch.Tensor – Training time at each epoch, as a tensor of shape (epochs,).
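The parameters and return values of fit correspond to a standard minibatch gradient-descent loop. The following is a generic PyTorch sketch of such a loop, not amatorch's implementation. The name fit_sketch is hypothetical, loss_fun is required rather than defaulted, and pairwise training is omitted. It also assumes that model is a torch.nn.Module, that plain SGD is the optimizer, that StepLR decay is applied once per epoch, and that the recorded time is cumulative. Whether fit reports cumulative or per-epoch time is not stated.

```python
import time

import torch


def fit_sketch(model, stimuli, labels, epochs, loss_fun,
               batch_size=512, learning_rate=0.1,
               decay_step=1000, decay_rate=1.0):
    """Generic minibatch gradient-descent loop (illustrative only).

    ASSUMPTIONS: plain SGD, StepLR decay stepped once per epoch, and
    cumulative wall-clock time recorded after each epoch.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=decay_step, gamma=decay_rate)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(stimuli, labels),
        batch_size=batch_size, shuffle=True)

    loss_history = torch.zeros(epochs)
    time_history = torch.zeros(epochs)
    start = time.perf_counter()
    for epoch in range(epochs):
        # One pass over shuffled minibatches, one SGD update per batch.
        for batch_stim, batch_labels in loader:
            optimizer.zero_grad()
            loss = loss_fun(model, batch_stim, batch_labels)
            loss.backward()
            optimizer.step()
        scheduler.step()  # advance the learning-rate schedule by one epoch
        # Record the full-dataset loss and cumulative elapsed time.
        with torch.no_grad():
            loss_history[epoch] = loss_fun(model, stimuli, labels).item()
        time_history[epoch] = time.perf_counter() - start
    return loss_history, time_history
```

Returning one value per epoch for both the loss and the time mirrors the two (epochs,)-shaped tensors documented under Returns.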