
T4transport (version 0.1.2)

sinkhorn: Wasserstein Distance by Entropic Regularization

Description

Because linear programming approaches to computing the Wasserstein distance are computationally expensive, Cuturi (2013) proposed an entropic regularization scheme as an efficient approximation to the original problem. The scheme introduces a regularization parameter \(\lambda > 0\) through the term $$\lambda h(\Gamma) = \lambda \sum_{m,n} \Gamma_{m,n} \log (\Gamma_{m,n}).$$ As \(\lambda\rightarrow 0\), the solution of the regularized problem approaches the solution of the original problem. However, small values of \(\lambda\) can cause numerical underflow. Our implementation returns an error when this happens, so please use a larger value of lambda when necessary.
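For reference, the regularized problem described above can be written out in full. This is a sketch in the notation of this page, assuming the transport cost is \(d(x_m, y_n)^p\) and writing \(w_x, w_y\) for the marginal weights; the set \(\Pi(w_x, w_y)\) and the scaling vectors \(u, v\) are symbols introduced here for illustration, and the package may scale \(\lambda\) differently internally.

```latex
\min_{\Gamma \in \Pi(w_x, w_y)} \;
  \sum_{m,n} \Gamma_{m,n}\, d(x_m, y_n)^p
  \;+\; \lambda \sum_{m,n} \Gamma_{m,n} \log (\Gamma_{m,n}),
\qquad
\Pi(w_x, w_y) = \left\{ \Gamma \ge 0 \;:\;
  \Gamma \mathbf{1}_N = w_x,\; \Gamma^\top \mathbf{1}_M = w_y \right\}.

% First-order optimality gives a plan of the scaled-kernel form
\Gamma_{m,n} = u_m \, \exp\!\left( -\frac{d(x_m, y_n)^p}{\lambda} \right) v_n .
```

Under this formulation, the kernel entries \(\exp(-d(x_m, y_n)^p/\lambda)\) shrink toward zero as \(\lambda\) decreases, which is where the numerical underflow comes from.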

Usage

sinkhorn(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 0.1, ...)

sinkhornD(D, p = 2, wx = NULL, wy = NULL, lambda = 0.1, ...)

Value

a named list containing

distance

\(\mathcal{W}_p\) distance value.

iteration

the number of iterations it took to converge.

plan

an \((M\times N)\) nonnegative matrix for the optimal transport plan.

Arguments

X

an \((M\times P)\) matrix of row observations.

Y

an \((N\times P)\) matrix of row observations.

p

an exponent for the order of the distance (default: 2).

wx

a length-\(M\) vector of marginal weights that sums to \(1\). If NULL (default), uniform weights are used.

wy

a length-\(N\) vector of marginal weights that sums to \(1\). If NULL (default), uniform weights are used.

lambda

a regularization parameter (default: 0.1).

...

extra parameters including

maxiter

maximum number of iterations (default: 496).

abstol

stopping criterion for iterations (default: 1e-10).

D

an \((M\times N)\) distance matrix \(d(x_m, y_n)\) between two sets of observations.
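How the arguments above fit together can be illustrated with a minimal base-R sketch of the textbook Sinkhorn scaling iteration. It is not the code used by T4transport: the function name `sinkhorn_sketch`, the use of `D^p` as the cost, the kernel `exp(-D^p / lambda)`, and the stopping rule on the scaling vector are all assumptions made for illustration. Only the argument names and defaults are taken from this page.

```r
# Textbook Sinkhorn scaling in base R; an illustrative sketch, NOT T4transport's code.
# `sinkhorn_sketch` is a hypothetical name; defaults mirror the documented ones.
sinkhorn_sketch <- function(D, p = 2, wx = NULL, wy = NULL,
                            lambda = 0.1, maxiter = 496, abstol = 1e-10) {
  M <- nrow(D)
  N <- ncol(D)
  if (is.null(wx)) wx <- rep(1 / M, M)   # uniform weights when NULL
  if (is.null(wy)) wy <- rep(1 / N, N)

  C <- D^p                 # assumed cost: distance raised to the power p
  K <- exp(-C / lambda)    # Gibbs kernel; entries underflow to 0 for small lambda
  if (any(rowSums(K) == 0) || any(colSums(K) == 0)) {
    stop("numerical underflow in exp(-D^p / lambda); try a larger lambda")
  }

  u <- rep(1, M)
  for (iter in seq_len(maxiter)) {
    u_old <- u
    v <- wy / as.vector(crossprod(K, u))   # rescale to match column marginals
    u <- wx / as.vector(K %*% v)           # rescale to match row marginals
    if (max(abs(u - u_old)) < abstol) break
  }

  plan <- sweep(u * K, 2, v, "*")          # same as diag(u) %*% K %*% diag(v)
  list(distance = sum(plan * C)^(1 / p), iteration = iter, plan = plan)
}

# Usage on a precomputed cross-distance matrix, as sinkhornD expects
set.seed(100)
X <- matrix(rnorm(20 * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(10 * 2, mean = +1), ncol = 2)
D <- as.matrix(dist(rbind(X, Y)))[1:20, 21:30]   # 20 x 10 distances d(x_m, y_n)
out <- sinkhorn_sketch(D, lambda = 1)
```

The returned list mirrors the documented `distance`, `iteration`, and `plan` fields. The explicit check on the kernel shows one plausible way an implementation can detect underflow and ask for a larger `lambda`.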

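References

Cuturi, M. (2013). "Sinkhorn Distances: Lightspeed Computation of Optimal Transport." In Advances in Neural Information Processing Systems, volume 26.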

Examples

# \donttest{
#-------------------------------------------------------------------
#  Wasserstein Distance between Samples from Two Bivariate Normal
#
# * class 1 : samples from Gaussian with mean=(-1, -1)
# * class 2 : samples from Gaussian with mean=(+1, +1)
#-------------------------------------------------------------------
## SMALL EXAMPLE
set.seed(100)
m = 20
n = 10
X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y

## COMPARE WITH WASSERSTEIN 
outw = wasserstein(X, Y)
skh1 = sinkhorn(X, Y, lambda=0.05)
skh2 = sinkhorn(X, Y, lambda=0.10)

## VISUALIZE : SHOW THE PLAN AND DISTANCE
pm1 = paste0("wasserstein plan ; distance=",round(outw$distance,2))
pm2 = paste0("sinkhorn lbd=0.05; distance=",round(skh1$distance,2))
pm5 = paste0("sinkhorn lbd=0.1 ; distance=",round(skh2$distance,2))

opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3))
image(outw$plan, axes=FALSE, main=pm1)
image(skh1$plan, axes=FALSE, main=pm2)
image(skh2$plan, axes=FALSE, main=pm5)
par(opar)
# }
