amatorch.optim
Routine to fit AMA filters using Gradient Descent.
Functions
- amatorch.optim.fit(model, stimuli, labels, epochs, loss_fun=None, batch_size=512, learning_rate=0.1, decay_step=1000, decay_rate=1, pairwise=False)
Learn AMA filters using Gradient Descent, with the option to specify a custom loss function.
- Parameters:
model (AMA model object) – The model used for fitting.
stimuli (torch.Tensor) – Stimuli tensor of shape (n_stim, n_channels, n_dim).
labels (torch.Tensor) – Label tensor of shape (n_stim).
epochs (int) – Number of training epochs.
loss_fun (callable, optional) – Loss function called with the model, stimuli, and labels that returns the loss to minimize. If None (default), the negative log posterior at the true category (cross-entropy) is used.
batch_size (int, optional) – Batch size, by default 512.
learning_rate (float, optional) – Initial learning rate, by default 0.1.
decay_step (int, optional) – Number of steps between successive learning-rate decays, by default 1000.
decay_rate (float, optional) – Factor by which the learning rate is multiplied at each decay, by default 1 (no decay).
pairwise (bool, optional) – Whether to train the filters in pairs, by default False.
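The learning_rate, decay_step, and decay_rate parameters read most naturally as a step-decay schedule. The sketch below assumes that interpretation: the learning rate is multiplied by decay_rate once every decay_step steps, which is the closed form used by PyTorch's torch.optim.lr_scheduler.StepLR. This page does not state whether fit uses StepLR, or whether a "step" counts batches or epochs; step_decay_lr is a hypothetical helper, not part of amatorch.

```python
def step_decay_lr(step, learning_rate=0.1, decay_step=1000, decay_rate=1.0):
    """Learning rate after `step` scheduler steps, ASSUMING step decay.

    Assumption: the rate is multiplied by `decay_rate` once every
    `decay_step` steps (same closed form as torch's StepLR).
    """
    return learning_rate * decay_rate ** (step // decay_step)


# With the default decay_rate=1, the learning rate never changes.
print([step_decay_lr(s) for s in (0, 1000, 5000)])
# [0.1, 0.1, 0.1]

# With decay_rate=0.5, the learning rate halves every 1000 steps.
print([step_decay_lr(s, decay_rate=0.5) for s in (0, 999, 1000, 2500)])
# [0.1, 0.1, 0.05, 0.025]
```

Under this reading, the defaults (decay_rate=1) keep the learning rate fixed at learning_rate for the whole run, and decay_step only matters once decay_rate is set below 1.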
- Returns:
torch.Tensor – Loss at each epoch, of shape (epochs,).
torch.Tensor – Training time at each epoch, of shape (epochs,).
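To illustrate the loss_fun contract (a callable taking the model, stimuli, and labels), here is a minimal sketch of two loss functions: one computing the mean negative log posterior at the true category, intended to mirror the documented default, and a custom label-smoothed variant. It assumes the model exposes per-class log posteriors through a log_posteriors(stimuli) method; that method name and the ToyModel class are hypothetical stand-ins, not the amatorch AMA model API, so adapt the first line of each loss to whatever your model actually provides.

```python
import torch
import torch.nn.functional as F


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for an AMA model (NOT the amatorch class).

    Assumption: the model returns class log posteriors of shape
    (n_stim, n_classes) from a `log_posteriors(stimuli)` method.
    """

    def __init__(self, n_channels, n_dim, n_classes):
        super().__init__()
        self.linear = torch.nn.Linear(n_channels * n_dim, n_classes)

    def log_posteriors(self, stimuli):
        # Flatten (n_stim, n_channels, n_dim) -> (n_stim, n_channels * n_dim)
        logits = self.linear(stimuli.flatten(start_dim=1))
        # Normalize so each row is a log probability distribution
        return torch.log_softmax(logits, dim=-1)


def neg_log_posterior_loss(model, stimuli, labels):
    """Mean negative log posterior at the true category (cross-entropy)."""
    log_post = model.log_posteriors(stimuli)
    return F.nll_loss(log_post, labels)


def smoothed_loss(model, stimuli, labels, smoothing=0.1):
    """Example custom loss: cross-entropy with label smoothing.

    F.cross_entropy applies log_softmax internally, which leaves
    already-normalized log posteriors unchanged.
    """
    log_post = model.log_posteriors(stimuli)
    return F.cross_entropy(log_post, labels, label_smoothing=smoothing)


# Random data in the shapes documented above
torch.manual_seed(0)
n_stim, n_channels, n_dim, n_classes = 64, 2, 10, 5
stimuli = torch.randn(n_stim, n_channels, n_dim)
labels = torch.randint(0, n_classes, (n_stim,))
model = ToyModel(n_channels, n_dim, n_classes)

loss = smoothed_loss(model, stimuli, labels)
loss.backward()  # gradients flow back to the model parameters
print(loss.item())

# With a real AMA model, the callable would be passed like this (not run here):
# amatorch.optim.fit(ama_model, stimuli, labels, epochs=10, loss_fun=smoothed_loss)
```

Because smoothed_loss gives smoothing a default value, it can be passed directly as loss_fun as long as fit calls it with exactly the three documented arguments; setting smoothing=0.0 recovers the plain cross-entropy computed by neg_log_posterior_loss. label_smoothing in F.cross_entropy requires PyTorch 1.10 or later.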