amatorch.data_wrangle
Utility functions for wrangling and subsampling data, outputs and statistics.
Functions
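| statistics_dim_subset(means, covariances, keep_inds) | Keep the statistics for a subset of the dimensions. |
| subsample_class_points(points, labels, n_per_class) | Return a subsample of points for each class (n_per_class points, or all of the class's points if it has fewer). |
| subsample_classes(points, labels, classes_to_keep=None) | Return only the points and labels for the classes to keep. |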
- amatorch.data_wrangle.statistics_dim_subset(means, covariances, keep_inds)
Keep the statistics for a subset of the dimensions.
- Parameters:
means (torch.Tensor) – Means of the data. (n_classes, n_dim)
covariances (torch.Tensor) – Covariances of the data. (n_classes, n_dim, n_dim)
keep_inds (torch.Tensor) – Indices of the dimensions to keep. (n_keep_dim,)
- Returns:
means_subset (torch.Tensor) – Means of the data for the subset of dimensions. (n_classes, n_keep_dim)
covariances_subset (torch.Tensor) – Covariances of the data for the subset of dimensions. (n_classes, n_keep_dim, n_keep_dim)
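The shapes above can be illustrated with plain tensor indexing. This is a minimal sketch under stated assumptions, not the library's implementation: the helper name `dim_subset_sketch` is hypothetical, and it assumes that subsetting a covariance means selecting the same rows and columns of each class matrix.

```python
import torch


def dim_subset_sketch(means, covariances, keep_inds):
    """Hypothetical stand-in that mirrors the documented shapes only."""
    # Keep the selected columns of the class means: (n_classes, n_keep_dim)
    means_subset = means[:, keep_inds]
    # Keep the selected rows, then the selected columns, of each covariance:
    # (n_classes, n_keep_dim, n_keep_dim)
    covariances_subset = covariances[:, keep_inds][:, :, keep_inds]
    return means_subset, covariances_subset


means = torch.arange(8.0).reshape(2, 4)      # (n_classes=2, n_dim=4)
covariances = torch.eye(4).repeat(2, 1, 1)   # (2, 4, 4)
keep_inds = torch.tensor([0, 2])             # (n_keep_dim=2,)

means_subset, covariances_subset = dim_subset_sketch(means, covariances, keep_inds)
print(means_subset.shape)        # torch.Size([2, 2])
print(covariances_subset.shape)  # torch.Size([2, 2, 2])
```

With the real function, the equivalent call would presumably be `statistics_dim_subset(means, covariances, keep_inds)` with the same inputs.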
- amatorch.data_wrangle.subsample_class_points(points, labels, n_per_class)
Return a subsample of points for each class. Each class contributes n_per_class points, or all of its points if it has fewer.
- Parameters:
points (torch.Tensor) – Points to subsample. (n_points, n_dim)
labels (torch.Tensor) – Labels or values of the points. (n_points,)
n_per_class (int) – Number of points to subsample for each class.
- Returns:
subsampled_points (torch.Tensor) – Subsampled points. (n_subsampled, n_dim), where n_subsampled is at most n_classes * n_per_class.
subsampled_labels (torch.Tensor) – Labels or values of the subsampled points. (n_subsampled,)
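The per-class cap can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: the helper name `subsample_class_points_sketch` is hypothetical, and because the source does not say how points are chosen within a class, the sketch simply takes the first ones in order.

```python
import torch


def subsample_class_points_sketch(points, labels, n_per_class):
    """Hypothetical stand-in: keep at most n_per_class points per class."""
    kept = []
    for c in torch.unique(labels):
        # Indices of all points that belong to class c
        class_inds = torch.nonzero(labels == c).squeeze(1)
        # Assumption: take the first points; slicing caps at the class size
        kept.append(class_inds[:n_per_class])
    kept = torch.cat(kept)
    return points[kept], labels[kept]


points = torch.randn(7, 3)                        # (n_points=7, n_dim=3)
labels = torch.tensor([0, 0, 0, 0, 0, 1, 1])      # class 0 has 5 points, class 1 has 2

sub_points, sub_labels = subsample_class_points_sketch(points, labels, n_per_class=3)
print(sub_points.shape)  # torch.Size([5, 3]): 3 from class 0, 2 from class 1
```

Class 1 has only two points, so the output has five rows rather than n_classes * n_per_class = 6.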
- amatorch.data_wrangle.subsample_classes(points, labels, classes_to_keep=None)
Return only the points and labels for the classes to keep.
- Parameters:
points (torch.Tensor) – Points to subsample. (n_points, n_dim)
labels (torch.Tensor) – Labels or values of the points. (n_points,)
classes_to_keep (list, optional) – List of classes to keep. Defaults to None.
- Returns:
subsampled_points (torch.Tensor) – Subsampled points. (n_points_subsampled, n_dim)
subsampled_labels (torch.Tensor) – Labels or values of the subsampled points. (n_points_subsampled,)
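Filtering by class can be sketched with a boolean mask. This is a minimal illustration, not the library's implementation: the helper name `subsample_classes_sketch` is hypothetical, and the behavior for the default `classes_to_keep=None` is not described in the source, so the sketch does not cover it.

```python
import torch


def subsample_classes_sketch(points, labels, classes_to_keep):
    """Hypothetical stand-in: keep only points whose label is in classes_to_keep."""
    # True where the label is one of the classes to keep
    keep_mask = torch.isin(labels, torch.as_tensor(classes_to_keep))
    return points[keep_mask], labels[keep_mask]


points = torch.arange(12.0).reshape(6, 2)    # (n_points=6, n_dim=2)
labels = torch.tensor([0, 1, 2, 0, 1, 2])

kept_points, kept_labels = subsample_classes_sketch(points, labels, [0, 2])
print(kept_points.shape)  # torch.Size([4, 2])
print(kept_labels)        # tensor([0, 2, 0, 2])
```

The mask preserves the original order of the points, so the labels of the kept points appear in the same order as in the input.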