amatorch.plot
Utilities for plotting model parameters (filters and statistics) and model outputs (responses, posteriors and estimates).
- amatorch.plot.draw_color_bar(colormap, limits, fig, title=None)
Draw a color bar for the given colormap and limits.
- Parameters:
colormap (str or matplotlib.colors.Colormap) – The colormap to use.
limits (list) – The minimum and maximum values for the color scale.
fig (matplotlib.figure.Figure) – The figure to draw the color bar on.
title (str, optional) – Title of the color bar. The default is None.
- Returns:
color_bar – The created color bar.
- Return type:
ColorbarBase
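A color bar that is not tied to any plotted artist can be built from a `Normalize` over `limits` combined with the colormap. The sketch below is a hypothetical stand-in (`draw_color_bar_sketch`), not the amatorch source. The position of the extra axes and the use of `set_title` for `title` are assumptions, and it returns the matplotlib colorbar object created by `fig.colorbar`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def draw_color_bar_sketch(colormap, limits, fig, title=None):
    """Hypothetical stand-in for amatorch.plot.draw_color_bar."""
    # Map the data range [limits[0], limits[1]] onto the colormap
    norm = Normalize(vmin=limits[0], vmax=limits[1])
    mappable = ScalarMappable(norm=norm, cmap=colormap)
    # Reserve a thin axes on the right edge of the figure (position is an assumption)
    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    color_bar = fig.colorbar(mappable, cax=cax)
    if title is not None:
        color_bar.ax.set_title(title)
    return color_bar


fig, ax = plt.subplots()
cbar = draw_color_bar_sketch("viridis", [-1.0, 1.0], fig, title="Value")
```

Returning the colorbar object lets the caller adjust ticks or labels afterwards.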
- amatorch.plot.plot_estimates_statistics(estimates, true_values, ax=None, ci_bars=False, quantiles=(0.025, 0.975))
Plot the mean estimate for each true value, optionally with confidence interval bars.
- Parameters:
estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).
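true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).
ax (matplotlib.axes.Axes, optional) – The axes to plot on. The default is None.
ci_bars (bool, optional) – Whether to plot confidence interval bars. The default is False.
quantiles (tuple of float, optional) – Lower and upper quantiles that bound the confidence intervals. The default is (0.025, 0.975).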
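One plausible reading of `ci_bars` and `quantiles` is that the interval at each true value spans those empirical quantiles of the estimates obtained for that value. The helpers below follow that assumption; the names `estimate_statistics_sketch` and `plot_estimates_statistics_sketch` are hypothetical and not part of amatorch, and they take NumPy arrays instead of tensors to stay self-contained.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def estimate_statistics_sketch(estimates, true_values, quantiles=(0.025, 0.975)):
    """Mean and quantile interval of the estimates for each distinct true value."""
    estimates = np.asarray(estimates, dtype=float)
    true_values = np.asarray(true_values, dtype=float)
    values = np.unique(true_values)  # sorted distinct true values
    means, lower, upper = [], [], []
    for value in values:
        group = estimates[true_values == value]
        means.append(group.mean())
        low, high = np.quantile(group, quantiles)
        lower.append(low)
        upper.append(high)
    return values, np.array(means), np.array(lower), np.array(upper)


def plot_estimates_statistics_sketch(estimates, true_values, ax=None, ci_bars=False,
                                     quantiles=(0.025, 0.975)):
    if ax is None:
        _, ax = plt.subplots()
    values, means, lower, upper = estimate_statistics_sketch(estimates, true_values, quantiles)
    ax.plot(values, means, "o-", label="mean estimate")
    if ci_bars:
        # Vertical segments from the lower to the upper quantile at each true value
        ax.vlines(values, lower, upper)
    ax.plot(values, values, "k--", label="identity")  # mean of an unbiased estimator lies here
    ax.set_xlabel("True value")
    ax.set_ylabel("Estimate")
    return ax


ax = plot_estimates_statistics_sketch([1.0, 3.0, 10.0, 20.0], [0, 0, 1, 1], ci_bars=True)
```

Drawing the interval with `vlines` rather than `errorbar` avoids negative error lengths when a skewed distribution puts the mean outside the quantile interval.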
- amatorch.plot.plot_filters(model, n_cols=2, n_filters=10)
Plot the filters of an AMA model.
- Parameters:
model (AMA model) – The model to plot the filters for.
n_cols (int, optional) – Number of columns in the grid layout. Default is 2.
n_filters (int, optional) – Number of filters to plot. Default is 10.
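A grid with `n_cols` columns needs `ceil(n_filters / n_cols)` rows, with any leftover cells hidden. The sketch below assumes the filters are available as a 2-D array of shape (n_total_filters, n_dims) and draws each one as a 1-D curve. How an AMA model stores its filters is not documented here, so the hypothetical `plot_filters_sketch` takes the array directly instead of a model.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def plot_filters_sketch(filters, n_cols=2, n_filters=10):
    """Hypothetical sketch: filters has shape (n_total_filters, n_dims)."""
    filters = np.asarray(filters)
    n_plot = min(n_filters, filters.shape[0])  # never request more filters than exist
    n_rows = math.ceil(n_plot / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False,
                             figsize=(3 * n_cols, 2 * n_rows))
    for i, ax in enumerate(axes.flat):
        if i < n_plot:
            ax.plot(filters[i])
            ax.set_title(f"Filter {i}")
        else:
            ax.set_visible(False)  # hide unused grid cells
    fig.tight_layout()
    return fig, axes


rng = np.random.default_rng(0)
fig, axes = plot_filters_sketch(rng.standard_normal((5, 30)), n_cols=2, n_filters=10)
```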
- amatorch.plot.scatter_estimates(estimates, true_values, jitter=0.0, ax=None, points_per_class=1000)
Scatter plot of the model estimates against the true values.
- Parameters:
estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).
true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).
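jitter (float, optional) – Amount of noise to add to the true values for plotting. The default is 0.0.
ax (matplotlib.axes.Axes, optional) – The axes to plot on. The default is None.
points_per_class (int, optional) – Maximum number of points to plot for each true value. The default is 1000.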
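Subsampling each class keeps densely sampled true values from hiding sparse ones, and jitter spreads out points that share a true value. The hypothetical `subsample_and_jitter` below assumes that `jitter` is the standard deviation of Gaussian noise added to the true values and that `points_per_class` caps the points drawn for each distinct true value; both are assumptions about the real function.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def subsample_and_jitter(estimates, true_values, jitter=0.0, points_per_class=1000, seed=0):
    """Keep at most points_per_class points per true value; jitter the true values."""
    rng = np.random.default_rng(seed)
    estimates = np.asarray(estimates, dtype=float)
    true_values = np.asarray(true_values, dtype=float)
    keep = []
    for value in np.unique(true_values):
        idx = np.flatnonzero(true_values == value)
        if idx.size > points_per_class:
            idx = rng.choice(idx, size=points_per_class, replace=False)
        keep.append(idx)
    keep = np.concatenate(keep)
    # Gaussian jitter with standard deviation `jitter` (an assumption)
    x = true_values[keep] + jitter * rng.standard_normal(keep.size)
    return x, estimates[keep]


true_values = np.repeat([0.0, 1.0, 2.0], 5)
estimates = np.arange(15, dtype=float)
x, y = subsample_and_jitter(estimates, true_values, jitter=0.0, points_per_class=2)
fig, ax = plt.subplots()
ax.scatter(x, y, s=4, alpha=0.5)
ax.set_xlabel("True value")
ax.set_ylabel("Estimate")
```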
- amatorch.plot.scatter_responses(responses, labels, ax=None, values=None, filter_pair=(0, 1), n_points=1000, classes_plot=None, legend_type='none', **kwargs)
Scatter plot of the filter responses, colored by stimulus class.
- Parameters:
responses (torch.Tensor) – Responses to the stimuli. Shape (n_stimuli, n_filters).
labels (torch.Tensor) – Class label (int64) of each stimulus. Shape (n_stimuli).
ax (matplotlib.axes.Axes, optional) – Axes to plot the scatter. If None, a new figure is created. The default is None.
values (torch.Tensor, optional) – Values to color the classes. The default is linearly spaced values between -1 and 1.
filter_pair (tuple, optional) – Pair of filters to plot. The default is (0, 1).
n_points (int, optional) – Number of points per class to plot. The default is 1000.
classes_plot (list, optional) – List of classes to plot. The default is all classes.
legend_type (str, optional) – Type of legend to add: ‘none’, ‘continuous’, ‘discrete’.
- Returns:
ax – Axes with the scatter plot.
- Return type:
matplotlib.axes.Axes
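The core of a response scatter is: take the two columns named by `filter_pair`, subsample up to `n_points` rows per class, and color each class by its entry in `values`. The sketch below (`scatter_responses_sketch`) is hypothetical. It assumes labels are the integers 0 to n_classes - 1, fixes the colormap to 'viridis', omits the legend, and forwards `**kwargs` to `Axes.scatter`; none of these details is confirmed by the documentation above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize


def scatter_responses_sketch(responses, labels, ax=None, values=None, filter_pair=(0, 1),
                             n_points=1000, classes_plot=None, seed=0, **kwargs):
    responses = np.asarray(responses)
    labels = np.asarray(labels)
    n_classes = int(labels.max()) + 1  # assumes labels 0, 1, ..., n_classes - 1
    if values is None:
        values = np.linspace(-1, 1, n_classes)
    if classes_plot is None:
        classes_plot = range(n_classes)
    if ax is None:
        _, ax = plt.subplots()
    cmap = plt.get_cmap("viridis")  # colormap choice is an assumption
    norm = Normalize(vmin=np.min(values), vmax=np.max(values))
    rng = np.random.default_rng(seed)
    i, j = filter_pair
    for c in classes_plot:
        idx = np.flatnonzero(labels == c)
        if idx.size > n_points:
            idx = rng.choice(idx, size=n_points, replace=False)  # per-class subsample
        ax.scatter(responses[idx, i], responses[idx, j],
                   color=cmap(norm(values[c])), **kwargs)
    ax.set_xlabel(f"Filter {i} response")
    ax.set_ylabel(f"Filter {j} response")
    return ax


rng = np.random.default_rng(1)
responses = rng.standard_normal((300, 4))
labels = np.repeat([0, 1, 2], 100)
ax = scatter_responses_sketch(responses, labels, filter_pair=(0, 2), n_points=50, s=5)
```

One `scatter` call per class keeps each class in its own collection, which makes a per-class legend straightforward to add.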
- amatorch.plot.statistics_ellipses(means, covariances, filter_pair=(0, 1), ax=None, values=None, classes_plot=None, color_map='viridis', legend_type='none', **kwargs)
Plot the ellipses of the filter response statistics across classes.
- Parameters:
means (torch.Tensor) – Means of the filter responses. Shape (n_classes, n_filters).
covariances (torch.Tensor) – Covariances of the filter responses. Shape (n_classes, n_filters, n_filters).
filter_pair (tuple of int, optional) – Pair of filters to plot. The default is (0, 1).
ax (matplotlib.axes.Axes, optional) – Axes to plot the ellipses. If None, a new figure is created. The default is None.
values (torch.Tensor, optional) – Values to color code the ellipses. Each value corresponds to a class. The default is linearly spaced values between -1 and 1.
classes_plot (list, optional) – List of classes to plot. The default is all classes.
color_map (str or matplotlib.colors.Colormap, optional) – Color map to use for the ellipses. The default is ‘viridis’.
legend_type (str, optional) – Type of legend to add: ‘none’, ‘continuous’, ‘discrete’.
- Returns:
ax – Axes with the ellipses.
- Return type:
matplotlib.axes.Axes
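Each ellipse can be derived from the 2 × 2 block of a class covariance matrix for the chosen filter pair: the eigenvectors give the axis directions, and the square roots of the eigenvalues give the standard deviations along them. The sketch below draws the one-standard-deviation contour by default; the `n_std` parameter, the fixed 'viridis' coloring and the helper names (`ellipse_params`, `statistics_ellipses_sketch`) are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.patches import Ellipse


def ellipse_params(mean, covariance, filter_pair=(0, 1), n_std=1.0):
    """Center, full width, full height and angle (degrees) of the n_std contour."""
    idx = list(filter_pair)
    center = np.asarray(mean, dtype=float)[idx]
    sub_cov = np.asarray(covariance, dtype=float)[np.ix_(idx, idx)]  # 2x2 block
    eigvals, eigvecs = np.linalg.eigh(sub_cov)  # eigenvalues in ascending order
    major = eigvecs[:, 1]  # direction of largest variance
    width = 2 * n_std * np.sqrt(eigvals[1])   # full length of the major axis
    height = 2 * n_std * np.sqrt(eigvals[0])  # full length of the minor axis
    angle = np.degrees(np.arctan2(major[1], major[0]))
    return center, width, height, angle


def statistics_ellipses_sketch(means, covariances, filter_pair=(0, 1), ax=None, n_std=1.0):
    if ax is None:
        _, ax = plt.subplots()
    values = np.linspace(-1, 1, len(means))  # one color value per class
    cmap = plt.get_cmap("viridis")
    norm = Normalize(vmin=-1, vmax=1)
    for k, (mean, cov) in enumerate(zip(means, covariances)):
        center, width, height, angle = ellipse_params(mean, cov, filter_pair, n_std)
        ax.add_patch(Ellipse(center, width, height, angle=angle, fill=False,
                             edgecolor=cmap(norm(values[k]))))
    ax.autoscale_view()
    return ax


means = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 0.0, 1.0]])
covariances = np.stack([np.diag([4.0, 1.0, 9.0])] * 3)
ax = statistics_ellipses_sketch(means, covariances, filter_pair=(0, 2))
```

Using `eigh` rather than `eig` relies on the covariance block being symmetric, which also guarantees real, sorted eigenvalues.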