
T4transport (version 0.1.2)

swdist: Sliced Wasserstein Distance

Description

Sliced Wasserstein (SW) distance (Rabin et al. 2012) is a popular alternative to the standard Wasserstein distance, owing to its computational efficiency as well as its favorable theoretical properties. For \(d\)-dimensional probability measures \(\mu\) and \(\nu\), the SW distance is defined as $$\mathcal{SW}_p (\mu, \nu) = \left( \int_{\mathbf{S}^{d-1}} \mathcal{W}_p^p \left( \langle \theta, \mu\rangle, \langle \theta, \nu \rangle \right) d\lambda (\theta) \right)^{1/p},$$ where \(\mathbf{S}^{d-1}\) is the \((d-1)\)-dimensional unit hypersphere, \(\lambda\) is the uniform distribution on \(\mathbf{S}^{d-1}\), and \(\langle \theta, \mu \rangle\) denotes the univariate distribution of \(\langle \theta, x \rangle\) for \(x \sim \mu\). In practice, the integral is approximated by Monte Carlo integration, averaging the univariate terms over randomly drawn directions \(\theta\).
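The Monte Carlo recipe can be sketched in base R as follows. This is an illustrative re-implementation under our own assumptions, not the code inside `swdist`: the helper names `wass1d_pp` and `sw_mc` are hypothetical, directions are drawn uniformly on the sphere by normalizing standard Gaussian vectors, and each univariate \(\mathcal{W}_p^p\) between uniformly weighted empirical measures is computed by matching their quantile functions, which also handles \(M \neq N\).

```r
# Hypothetical helper (not part of T4transport): W_p^p between two
# univariate empirical measures with uniform weights, obtained by
# integrating |F^{-1}(u) - G^{-1}(u)|^p over u in (0, 1].
wass1d_pp <- function(x, y, p = 2) {
  x <- sort(as.vector(x))
  y <- sort(as.vector(y))
  m <- length(x)
  n <- length(y)
  # levels at which either empirical quantile function jumps
  t  <- sort(unique(c(seq_len(m) / m, seq_len(n) / n)))
  dt <- diff(c(0, t))
  # both quantile functions are constant on each interval (t[k-1], t[k]];
  # evaluate them at the interval midpoints
  mid <- t - dt / 2
  ix <- pmin(m, pmax(1, ceiling(mid * m)))
  iy <- pmin(n, pmax(1, ceiling(mid * n)))
  sum(dt * abs(x[ix] - y[iy])^p)
}

# Hypothetical Monte Carlo estimator of SW_p for row-observation matrices
# X (M x P) and Y (N x P); follows the definition above, not swdist itself.
sw_mc <- function(X, Y, p = 2, nproj = 500) {
  P <- ncol(X)
  projdist <- numeric(nproj)
  for (k in seq_len(nproj)) {
    theta <- rnorm(P)
    theta <- theta / sqrt(sum(theta^2))   # uniform direction on S^{P-1}
    projdist[k] <- wass1d_pp(X %*% theta, Y %*% theta, p)
  }
  # average W_p^p over the sampled directions, then take the 1/p-th root
  list(distance = mean(projdist)^(1 / p), projdist = projdist)
}

set.seed(100)
X <- matrix(rnorm(20 * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(30 * 2, mean = +1), ncol = 2)
sw_mc(X, Y, p = 2, nproj = 100)$distance
```

In this sketch `projdist` holds the per-direction \(\mathcal{W}_p^p\) values; whether `swdist` stores those or their \(p\)-th roots is not stated on this page.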

Usage

swdist(X, Y, p = 2, ...)

Value

a named list containing

distance

\(\mathcal{SW}_p\) distance value.

projdist

a length-nproj vector of projected univariate distances, one per sampled direction.

Arguments

X

an \((M\times P)\) matrix of row observations.

Y

an \((N\times P)\) matrix of row observations.

p

an exponent for the order of the distance (default: 2).

...

extra parameters including

nproj

the number of Monte Carlo samples for SW computation (default: 496).
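Because the spherical integral is estimated by averaging over `nproj` random directions, the Monte Carlo error should shrink roughly like 1/sqrt(nproj). The sketch below illustrates this trade-off in a case where the exact answer is known: if \(Y\) is \(X\) shifted by a vector \(v\), every projected \(\mathcal{W}_2^2\) equals \(\langle \theta, v \rangle^2\), so \(\mathcal{SW}_2 = \|v\| / \sqrt{P}\). The helper `sw2_shift_mc` is hypothetical and does not call `swdist`.

```r
# Hypothetical illustration (not part of T4transport): for a pure shift by v,
# each projected W_2^2 is <theta, v>^2, so only the Monte Carlo error remains.
sw2_shift_mc <- function(v, nproj) {
  P <- length(v)
  theta <- matrix(rnorm(nproj * P), ncol = P)
  theta <- theta / sqrt(rowSums(theta^2))   # nproj uniform directions (rows)
  sqrt(mean((theta %*% v)^2))               # MC estimate of SW_2
}

set.seed(100)
v <- c(3, 4)                                # exact SW_2 = 5 / sqrt(2)
for (nproj in c(10, 100, 1000)) {
  est <- replicate(200, sw2_shift_mc(v, nproj))
  cat(sprintf("nproj = %4d  mean = %.4f  sd = %.4f\n",
              nproj, mean(est), sd(est)))
}
```

We would expect the printed standard deviation to drop by about a factor of sqrt(10) for each tenfold increase in `nproj`, while the cost grows linearly in `nproj`.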

References

Examples

# \donttest{
#-------------------------------------------------------------------
#  Sliced-Wasserstein Distance between Two Bivariate Normal
#
# * class 1 : samples from Gaussian with mean=(-1, -1)
# * class 2 : samples from Gaussian with mean=(+1, +1)
#-------------------------------------------------------------------
# SMALL EXAMPLE
library(T4transport)
set.seed(100)
m = 20
n = 30
X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y

# COMPUTE THE SLICED-WASSERSTEIN DISTANCE
outsw <- swdist(X, Y, nproj=100)

# VISUALIZE
# prepare ingredients for plotting
plot_x = seq_along(outsw$projdist)  # 1, ..., nproj (here 100 projections)
plot_y = base::cumsum(outsw$projdist)/plot_x

# draw
opar <- par(no.readonly=TRUE)
plot(plot_x, plot_y, type="b", cex=0.1, lwd=2,
     xlab="number of MC samples", ylab="distance",
     main="Effect of MC Sample Size")
abline(h=outsw$distance, col="red", lwd=2)
legend("bottomright", legend="SW Distance", 
       col="red", lwd=2)
par(opar)
# }
