amatorch.plot.inference

Functions

plot_estimates_statistics(estimates, true_values)

Plot the mean estimates by true value.

scatter_estimates(estimates, true_values[, ...])

Scatter-plot the model estimates against the true values.

amatorch.plot.inference.plot_estimates_statistics(estimates, true_values, ax=None, ci_bars=False, quantiles=(0.025, 0.975))

Plot the mean estimates by true value.

Parameters:
  • estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).

  • true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).

  • ax (matplotlib.axes.Axes) – The axes to plot on.

  • ci_bars (bool) – Whether to plot confidence intervals.
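The statistics behind this plot can be sketched as follows. The snippet groups the estimates by true value and computes each group's mean and the quantiles given by `quantiles`. It assumes the function summarizes each true value this way, and that `quantiles` sets the interval bounds drawn when `ci_bars=True`. The helper `estimate_statistics` is hypothetical and is not part of amatorch.

```python
import torch


def estimate_statistics(estimates, true_values, quantiles=(0.025, 0.975)):
    """Hypothetical helper: per-true-value mean and quantiles of the estimates.

    Assumption: this mirrors what plot_estimates_statistics summarizes;
    it is an illustration, not amatorch's implementation.
    """
    estimates = estimates.float()
    # Sorted unique true values, plus each stimulus's group index
    levels, group = torch.unique(true_values, return_inverse=True)
    q = torch.tensor(quantiles, dtype=estimates.dtype)
    means = torch.empty(len(levels))
    bounds = torch.empty(len(levels), len(quantiles))
    for i in range(len(levels)):
        group_estimates = estimates[group == i]
        means[i] = group_estimates.mean()
        # Lower/upper interval bounds for this true value
        bounds[i] = torch.quantile(group_estimates, q)
    return levels, means, bounds


# Usage: two true values with a few noisy estimates each
true_values = torch.tensor([0, 0, 1, 1, 1])
estimates = torch.tensor([0.1, -0.1, 0.9, 1.0, 1.1])
levels, means, bounds = estimate_statistics(estimates, true_values)
```

To draw the result, `levels`, `means`, and the distances from `means` to `bounds` (stacked as a 2 × N array of lower and upper errors) could be passed to matplotlib's `Axes.errorbar`. Whether amatorch draws its intervals this way is not stated in these docs.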

amatorch.plot.inference.scatter_estimates(estimates, true_values, jitter=0.0, ax=None, points_per_class=1000)

Scatter-plot the model estimates against the true values.

Parameters:
  • estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).

  • true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).

  • jitter (float) – Amount of noise to add to the true values for plotting.

  • ax (matplotlib.axes.Axes) – The axes to plot on.
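The data preparation for this scatter plot can be sketched as follows. The sketch makes two assumptions that these docs do not confirm. First, `points_per_class` caps the number of randomly chosen stimuli shown per true value. Second, `jitter` is the standard deviation of Gaussian noise added to the true values, whereas the docs only say "amount of noise". `prepare_scatter_points` is a hypothetical helper, not amatorch code.

```python
import torch


def prepare_scatter_points(estimates, true_values, jitter=0.0,
                           points_per_class=1000, generator=None):
    """Hypothetical helper: subsample each class and jitter the x positions."""
    xs, ys = [], []
    for level in torch.unique(true_values):
        idx = torch.nonzero(true_values == level).flatten()
        # Assumption: keep at most `points_per_class` random stimuli per class
        if len(idx) > points_per_class:
            perm = torch.randperm(len(idx), generator=generator)
            idx = idx[perm[:points_per_class]]
        x = true_values[idx].float()
        # Assumption: `jitter` is the std of Gaussian noise on the true values
        x = x + jitter * torch.randn(len(idx), generator=generator)
        xs.append(x)
        ys.append(estimates[idx].float())
    return torch.cat(xs), torch.cat(ys)


# Usage: cap each class at 2 points, no jitter
g = torch.Generator().manual_seed(0)
true_values = torch.tensor([0] * 5 + [1] * 3)
estimates = torch.arange(8, dtype=torch.float32)
x, y = prepare_scatter_points(estimates, true_values,
                              points_per_class=2, generator=g)
```

The returned `x` and `y` could then be passed to `ax.scatter(x, y)`. Jittering only the horizontal position keeps the estimates themselves unchanged while spreading out points that share a true value.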