Pairwise Wasserstein Distances Along Projection Directions Between Several Empirical Distributions
Source: R/sliced_wasserstein.R, compute_all_distances.Rd
This function computes the squared 2-Wasserstein distances between the projections of several empirical distributions along specified projection directions. The projection directions can optionally be transformed by a linear map. The function can output either the squared distances along each projection direction or the squared distances averaged over the projection directions.
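For one-dimensional empirical distributions with equal weights, the squared 2-Wasserstein distance equals the integral over u in (0, 1) of the squared difference between the two quantile functions. The sketch below computes that integral exactly for samples of possibly different sizes. It is an illustration only: the helper name `w2_squared_1d` is hypothetical, and this is not necessarily how `compute_all_distances()` computes the distances internally.

```r
# Hypothetical helper (not part of the package): exact squared 2-Wasserstein
# distance between two 1-D empirical distributions with uniform weights.
w2_squared_1d <- function(x, y) {
  x <- sort(x)
  y <- sort(y)
  n <- length(x)
  m <- length(y)
  # Right endpoints of the intervals on which both quantile functions are constant
  u <- sort(unique(c(seq_len(n) / n, seq_len(m) / m)))
  widths <- diff(c(0, u))
  # Order statistic returned by each quantile function on (u[k-1], u[k]];
  # the small offset guards against floating-point round-up at exact breakpoints
  ix <- pmin(ceiling(u * n - 1e-9), n)
  iy <- pmin(ceiling(u * m - 1e-9), m)
  sum(widths * (x[ix] - y[iy])^2)
}

# With equal sample sizes this reduces to the mean squared difference of sorted samples
set.seed(1)
a <- rnorm(10)
b <- rnorm(10)
all.equal(w2_squared_1d(a, b), mean((sort(a) - sort(b))^2))
#> [1] TRUE
```

Different sizes are handled by merging the breakpoints of both quantile functions, so, for example, `w2_squared_1d(0, c(0, 2))` returns 2, the squared distance between a point mass at 0 and an equal mixture of point masses at 0 and 2.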
Usage
compute_all_distances(
  distributions,
  thetas,
  A = NULL,
  verbose = TRUE,
  keep_projections = TRUE
)
Arguments
- distributions
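A list of matrices, each representing an empirical distribution with one row per observation. All matrices must have the same number of columns (the dimension of the space).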
- thetas
A matrix, each row of which represents a projection direction.
- A
Optionally, a matrix used to transform each projection direction. If NULL (the default), the directions are used as given.
- verbose
If TRUE, show progress.
- keep_projections
If TRUE, return a separate squared-distance matrix for each projection direction. If FALSE, return the average of these matrices over the projection directions.
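To make the roles of `thetas` and `A` concrete, the sketch below projects each distribution onto every direction. It assumes that rows of each distribution matrix are observations and that a transformed direction is `A %*% theta` for each row `theta` of `thetas`; this page does not state the exact convention, although for a diagonal `A`, as in the examples, `A` and `t(A)` give the same result. The function name `project_all` is hypothetical.

```r
# Hypothetical sketch: 1-D projections of each distribution onto each direction.
# Assumes rows of each matrix are observations and directions become A %*% theta.
project_all <- function(distributions, thetas, A = NULL) {
  # Row k of `dirs` is the (optionally transformed) k-th direction
  dirs <- if (is.null(A)) thetas else thetas %*% t(A)
  # For each distribution, column k holds the projections onto direction k
  lapply(distributions, function(X) X %*% t(dirs))
}

X <- matrix(c(1, 2, 3, 4, 5, 6), ncol = 2)  # 3 observations in 2 dimensions
proj <- project_all(list(X), diag(1, 2), A = diag(c(2, 1)))
proj[[1]][, 1]  # first coordinate, scaled by 2
#> [1] 2 4 6
```

Each column of a projected matrix is a one-dimensional sample, to which a 1-D squared 2-Wasserstein computation can then be applied pairwise across distributions.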
Value
If keep_projections = TRUE, a list of squared-distance matrices, one for each projection direction;
otherwise, a matrix with the averaged squared distances.
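The two output forms are related: averaging the list returned with `keep_projections = TRUE` element-wise over directions should reproduce the matrix returned with `keep_projections = FALSE`, assuming the averaging is an unweighted arithmetic mean (the documentation says only that the matrices are averaged). The toy matrices below are stand-ins, not real output.

```r
# Toy stand-ins for per-direction squared-distance matrices (not real output)
per_direction <- list(
  matrix(c(0, 1, 1, 0), nrow = 2),
  matrix(c(0, 3, 3, 0), nrow = 2)
)
# Element-wise mean over projection directions
averaged <- Reduce(`+`, per_direction) / length(per_direction)
averaged
#>      [,1] [,2]
#> [1,]    0    2
#> [2,]    2    0
```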
Examples
M1 <- matrix(rnorm(50), ncol = 5)
M2 <- matrix(rnorm(150), ncol = 5)
M3 <- matrix(rnorm(250), ncol = 5)
# Sliced Wasserstein:
my_directions <- generate_directions(20, 5)
compute_all_distances(list(M1, M2, M3), my_directions,
                      keep_projections = FALSE, verbose = FALSE)
#> [,1] [,2] [,3]
#> [1,] 0.0000000 0.2662411 0.2123149
#> [2,] 0.2662411 0.0000000 0.2030276
#> [3,] 0.2123149 0.2030276 0.0000000
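Here `generate_directions(20, 5)` supplies 20 directions in five dimensions. Its sampling scheme is not documented on this page. A common choice for sliced Wasserstein distances is directions drawn uniformly from the unit sphere, which the hypothetical stand-in below produces by normalizing standard Gaussian vectors. It is a sketch under that assumption, not the package's implementation.

```r
# Hypothetical stand-in for generate_directions(): n unit vectors in R^d,
# uniformly distributed on the sphere (normalized standard Gaussian vectors)
random_unit_directions <- function(n, d) {
  Z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  Z / sqrt(rowSums(Z^2))  # divides row i by its Euclidean norm
}

set.seed(42)
dirs <- random_unit_directions(20, 5)
dim(dirs)
#> [1] 20  5
```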
# Marginal Wasserstein distances:
marginal_wass <- compute_all_distances(list(M1, M2, M3), diag(1, 5),
                                       keep_projections = TRUE, verbose = FALSE)
marginal_wass[[3]] # along third dimension
#> [,1] [,2] [,3]
#> [1,] 0.0000000 0.11677560 0.15943208
#> [2,] 0.1167756 0.00000000 0.05865423
#> [3,] 0.1594321 0.05865423 0.00000000
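Passing `diag(1, 5)` as `thetas` makes each projection direction a coordinate axis, so projecting onto the third direction simply extracts the third column of each matrix. That is why these are marginal distances. A quick check with a standalone matrix:

```r
M <- matrix(rnorm(50), ncol = 5)
e3 <- diag(1, 5)[3, ]  # third row of the identity: the third coordinate axis
# Projection onto e3 equals the third column of M
all.equal(drop(M %*% e3), M[, 3])
#> [1] TRUE
```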
# Reweight the projection directions with a diagonal linear map A
A <- diag(c(4, 0.5, 3, 2, 1))
shear_wass <- compute_all_distances(list(M1, M2, M3), diag(1, 5), A = A,
                                    keep_projections = TRUE, verbose = FALSE)
shear_wass[[3]]
#> [,1] [,2] [,3]
#> [1,] 0.000000 1.0509804 1.4348888
#> [2,] 1.050980 0.0000000 0.5278881
#> [3,] 1.434889 0.5278881 0.0000000
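Comparing with the marginal example, every entry of `shear_wass[[3]]` is 9 times the corresponding entry of `marginal_wass[[3]]` (for instance, 0.1167756 × 9 ≈ 1.0509804). With this diagonal `A`, the third coordinate direction is scaled by 3, so every projected value is multiplied by 3 and the squared 2-Wasserstein distance scales by 3^2 = 9. The self-contained check below illustrates this scaling for equal-size samples, where the squared distance is the mean squared difference of the sorted values; `w2sq` is a hypothetical helper.

```r
set.seed(7)
x <- rnorm(30)
y <- rnorm(30, mean = 1)
# Hypothetical helper: squared 2-Wasserstein distance for equal-size 1-D samples
w2sq <- function(a, b) mean((sort(a) - sort(b))^2)
w2sq(3 * x, 3 * y) / w2sq(x, y)
#> [1] 9
```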