amatorch.data_wrangle

Utility functions for wrangling and subsampling data, outputs and statistics.

Functions

statistics_dim_subset(means, covariances, ...)

Keep the statistics for a subset of the dimensions.

subsample_class_points(points, labels, ...)

Return a subsample of points for each class: n_per_class points, or all of the class's points if it has fewer.

subsample_classes(points, labels[, ...])

Return only the points and labels for the classes to keep.

amatorch.data_wrangle.statistics_dim_subset(means, covariances, keep_inds)

Keep the statistics for a subset of the dimensions.

Parameters:
  • means (torch.Tensor) – Means of the data. (n_classes, n_dim)

  • covariances (torch.Tensor) – Covariances of the data. (n_classes, n_dim, n_dim)

  • keep_inds (torch.Tensor) – Indices of the dimensions to keep. (n_keep_dim,)

Returns:

  • means_subset (torch.Tensor) – Means of the data for the subset of dimensions. (n_classes, n_keep_dim)

  • covariances_subset (torch.Tensor) – Covariances of the data for the subset of dimensions. (n_classes, n_keep_dim, n_keep_dim)
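The shapes above suggest that the subset is obtained by indexing the means along their last dimension and each covariance matrix along both its rows and columns. A minimal sketch of that indexing in plain PyTorch follows; the helper name `dim_subset_sketch` is hypothetical, and this is an illustration of the described behavior, not amatorch's implementation.

```python
import torch


def dim_subset_sketch(means, covariances, keep_inds):
    """Hypothetical stand-in illustrating the documented shapes."""
    # Keep the selected dimensions of each class mean: (n_classes, n_keep_dim)
    means_subset = means[:, keep_inds]
    # Keep the selected rows, then the selected columns, of each covariance:
    # (n_classes, n_keep_dim, n_keep_dim)
    covariances_subset = covariances[:, keep_inds][:, :, keep_inds]
    return means_subset, covariances_subset


# Two classes, four dimensions
means = torch.arange(8.0).reshape(2, 4)
covariances = torch.diag(torch.tensor([1.0, 2.0, 3.0, 4.0])).repeat(2, 1, 1)
keep_inds = torch.tensor([0, 2])

means_subset, covariances_subset = dim_subset_sketch(means, covariances, keep_inds)
print(means_subset.shape)        # torch.Size([2, 2])
print(covariances_subset.shape)  # torch.Size([2, 2, 2])
```

Indexing rows and columns in two steps keeps the full sub-block of each covariance matrix, including the cross-covariances between the kept dimensions.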

amatorch.data_wrangle.subsample_class_points(points, labels, n_per_class)

Return a subsample of points for each class: n_per_class points, or all of the class's points if it has fewer.

Parameters:
  • points (torch.Tensor) – Points to subsample. (n_points, n_dim)

  • labels (torch.Tensor) – Labels or values of the points. (n_points,)

  • n_per_class (int) – Number of points to subsample for each class.

Returns:

  • subsampled_points (torch.Tensor) – Subsampled points. (n_subsampled, n_dim), where n_subsampled is at most n_classes * n_per_class

  • subsampled_labels (torch.Tensor) – Labels or values of the subsampled points. (n_subsampled,)
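The per-class cap can be sketched in plain PyTorch as below. The helper name `subsample_class_points_sketch` is hypothetical, and picking points at random within each class is an assumption: the description above does not say how the retained points are chosen.

```python
import torch


def subsample_class_points_sketch(points, labels, n_per_class):
    """Hypothetical stand-in: keep at most n_per_class points per class."""
    kept = []
    for c in torch.unique(labels):
        # Indices of all points belonging to class c
        class_inds = torch.nonzero(labels == c).flatten()
        # Keep n_per_class points, or every point if the class is smaller
        n_keep = min(n_per_class, class_inds.numel())
        # Assumption: choose the kept points uniformly at random
        chosen = torch.randperm(class_inds.numel())[:n_keep]
        kept.append(class_inds[chosen])
    kept = torch.cat(kept)
    return points[kept], labels[kept]


# Class 0 has five points, class 1 has only two
points = torch.randn(7, 3)
labels = torch.tensor([0, 0, 0, 0, 0, 1, 1])

sub_points, sub_labels = subsample_class_points_sketch(points, labels, n_per_class=3)
print(sub_points.shape)            # torch.Size([5, 3])
print(torch.bincount(sub_labels))  # tensor([3, 2])
```

In this example class 1 contributes only its two available points, so the output has five rows rather than n_classes * n_per_class = 6.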

amatorch.data_wrangle.subsample_classes(points, labels, classes_to_keep=None)

Return only the points and labels for the classes to keep.

Parameters:
  • points (torch.Tensor) – Points to subsample. (n_points, n_dim)

  • labels (torch.Tensor) – Labels or values of the points. (n_points,)

  • classes_to_keep (list, optional) – List of classes to keep. Defaults to None.

Returns:

  • subsampled_points (torch.Tensor) – Subsampled points. (n_points_subsampled, n_dim)

  • subsampled_labels (torch.Tensor) – Labels or values of the subsampled points. (n_points_subsampled,)
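Filtering by class amounts to building a boolean mask over the labels. A minimal sketch in plain PyTorch follows; the helper name `subsample_classes_sketch` is hypothetical, and it requires an explicit list because the behavior for classes_to_keep=None is not described here.

```python
import torch


def subsample_classes_sketch(points, labels, classes_to_keep):
    """Hypothetical stand-in: keep only points whose label is in classes_to_keep."""
    keep = torch.as_tensor(classes_to_keep, dtype=labels.dtype)
    # True for every point whose label appears in classes_to_keep
    mask = torch.isin(labels, keep)
    return points[mask], labels[mask]


points = torch.arange(10.0).reshape(5, 2)
labels = torch.tensor([0, 1, 2, 1, 0])

sub_points, sub_labels = subsample_classes_sketch(points, labels, [0, 2])
print(sub_labels)        # tensor([0, 2, 0])
print(sub_points.shape)  # torch.Size([3, 2])
```

The mask preserves the original ordering of the retained points, so the points and labels stay aligned.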