Compute optimal transport between unnormalized images / mass distributions on grids (pgrid objects) or between mass distributions on general point patterns (wpp objects) under the option that mass can be disposed of. Transport cost per unit is the Euclidean distance of the transport raised to the \(p\)-th power. Disposal cost per unit is \(C^p\).
unbalanced(a, b, ...)

# S3 method for pgrid
unbalanced(
a,
b,
p = 1,
C = NULL,
method = c("networkflow", "revsimplex"),
output = c("dist", "all", "rawres"),
threads = 1,
...
)
# S3 method for wpp
unbalanced(
a,
b,
p = 1,
C = NULL,
method = c("networkflow", "revsimplex"),
output = c("dist", "all", "rawres"),
threads = 1,
...
)
If output = "dist", a single numeric: the unbalanced \((p,C)\)-Wasserstein distance.
Otherwise a list. If output = "all", the list is of class ut_pgrid or ut_wpp according to the class of the objects a and b. It has a, b, p, C as attributes and the following components:
dist: same as for output = "dist".
plan: an optimal transport plan. This is a data frame with columns from, to and mass that specifies from which element of a to which element of b what amount of mass is sent. from and to are specified as vector indices in terms of the usual column-major enumeration of the matrices a$mass and b$mass (for wpp objects these are simply the indices of the points). The plan can be plotted via plot(a, b, plan), i.e. plot.pgrid(a, b, plan) for pgrid objects.
atrans, btrans: matrices (pgrid) or vectors (wpp) specifying the masses transported from each point and to each point, respectively. Corresponds to \((\pi^{(1)}_x)_{x \in S}\) and \((\pi^{(2)}_y)_{y \in S}\) in the minimization problem described below.
aextra, bextra: matrices (pgrid) or vectors (wpp) specifying the amount of mass at each point of a and b, respectively, that is not transported and needs to be disposed of. Corresponds to \((a_x - \pi^{(1)}_x)_{x \in S}\) and \((b_y - \pi^{(2)}_y)_{y \in S}\).
inplace: (pgrid only) a matrix specifying the amount of mass at each point that can stay in place. Corresponds to \((\pi_{x,x})_{x \in S}\).
Note that atrans + aextra + inplace (pgrid) or atrans + aextra (wpp) must be equal to a$mass, and likewise for b. A warning is issued if this is not the case (which may indeed happen from time to time for method = "revsimplex", but the reported error should be very small).
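The identity above can be checked directly on a returned object. The following is a minimal sketch, assuming the transport package is installed and that the list components are named atrans, btrans, aextra, bextra, inplace and plan as listed above; it also turns the column-major indices in plan into (row, column) pixel indices with base R's arrayInd(). The variable names ok_a, ok_b, from_rc and to_rc are only illustrative.

```r
library(transport)

# Two compatible 3 x 4 images (the same as in the examples)
a <- pgrid(matrix(1:12, 3, 4))
b <- pgrid(matrix(c(9:4, 12:7), 3, 4))
res <- unbalanced(a, b, p = 1, C = 0.3, output = "all")

# Mass balance for pgrid objects: transported + disposed + in-place mass
# should reproduce the original masses (up to a small numerical error)
ok_a <- isTRUE(all.equal(as.vector(res$atrans + res$aextra + res$inplace),
                         as.vector(a$mass)))
ok_b <- isTRUE(all.equal(as.vector(res$btrans + res$bextra + res$inplace),
                         as.vector(b$mass)))

# plan$from and plan$to are column-major vector indices into a$mass and
# b$mass; arrayInd() converts them into (row, column) pairs
from_rc <- arrayInd(res$plan$from, dim(a$mass))
to_rc <- arrayInd(res$plan$to, dim(b$mass))
head(data.frame(from_row = from_rc[, 1], from_col = from_rc[, 2],
                to_row = to_rc[, 1], to_col = to_rc[, 2],
                mass = res$plan$mass))
```

For wpp objects the same check applies without the inplace term, and no index conversion is needed.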
a, b: objects of class pgrid or wpp that are compatible.
...: other arguments.
p: a power \(\geq 1\) applied to the transport and disposal costs; the order of the resulting unbalanced Wasserstein metric.
C: the base disposal cost (without the power p).
method: one of "networkflow" and "revsimplex", specifying the algorithm used. See details.
output: character. One of "dist", "all" and "rawres". Determines what the function returns: only the unbalanced Wasserstein distance; all available information about the transport plan and the extra mass; or the raw result obtained by the networkflow algorithm. The latter is in the same format as in the transport function with option fullreturn=TRUE. The choice output = "rawres" is mainly intended for internal use.
threads: an integer specifying the number of threads for parallel computing in connection with the networkflow method.
Given two non-negative mass distributions \(a=(a_x)_{x \in S}\), \(b=(b_y)_{y \in S}\) on a set \(S\) (a pixel grid / image if a, b are of class pgrid, or a more general weighted point pattern if a, b are of class wpp), this function minimizes the functional
$$\sum_{x,y \in S} \pi_{x,y} d(x,y)^p + C^p \bigl( \sum_{x \in S} (a_x - \pi^{(1)}_x) + \sum_{y \in S} (b_y - \pi^{(2)}_y) \bigr)$$
over all \((\pi_{x,y})_{x,y \in S}\) satisfying
$$0 \leq \pi^{(1)}_x := \sum_{y \in S} \pi_{x,y} \leq a_x \ \textrm{and} \ 0 \leq \pi^{(2)}_y := \sum_{x \in S} \pi_{x,y} \leq b_y.$$
Thus \(\pi_{x,y}\) denotes the amount of mass transported from \(x\) to \(y\), whereas \(\pi^{(1)}_x\) and \(\pi^{(2)}_y\) are the total mass transported away from \(x\) and total mass transported to \(y\), respectively. Accordingly \(\sum_{x \in S} (a_x - \pi^{(1)}_x)\) and \(\sum_{y \in S} (b_y - \pi^{(2)}_y)\) are the total amounts of mass of \(a\) and \(b\), respectively, that need to be disposed of.
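To make the functional concrete, here is a base-R sketch that evaluates it for a given plan on two small weighted point patterns. The helper ut_objective and the toy data are hypothetical (not part of the transport package) and only illustrate the formula: the function does not search for an optimal plan, it only scores the plan it is given and returns both the value of the functional and its \(1/p\)-th power.

```r
# Hypothetical helper, not part of the transport package: evaluates the
# unbalanced functional for a *given* plan on weighted point patterns.
#   plan          data frame with columns from, to, mass (row indices into xa / xb)
#   xa, xb        coordinate matrices (one point per row)
#   amass, bmass  non-negative masses of the points in xa and xb
ut_objective <- function(plan, xa, xb, amass, bmass, p = 1, C) {
  # Euclidean distance travelled by each individual transport
  diffs <- xa[plan$from, , drop = FALSE] - xb[plan$to, , drop = FALSE]
  d <- sqrt(rowSums(diffs^2))
  transport_cost <- sum(plan$mass * d^p)

  # pi1[x]: total mass sent away from x; pi2[y]: total mass received at y
  pi1 <- vapply(seq_along(amass), function(i) sum(plan$mass[plan$from == i]), numeric(1))
  pi2 <- vapply(seq_along(bmass), function(j) sum(plan$mass[plan$to == j]), numeric(1))
  # Upper constraints pi1 <= a and pi2 <= b (small tolerance for rounding)
  stopifnot(all(pi1 <= amass + 1e-9), all(pi2 <= bmass + 1e-9))

  # Every unit of mass that is not transported is disposed of at cost C^p
  disposal_cost <- C^p * (sum(amass - pi1) + sum(bmass - pi2))
  value <- transport_cost + disposal_cost
  list(value = value, dist = value^(1 / p))
}

# Toy example: one unit of mass is sent from the first point of a
# to the first point of b (distance 0.5); everything else is disposed of
xa <- matrix(c(0, 0,
               1, 0), ncol = 2, byrow = TRUE)
xb <- matrix(c(0, 0.5,
               3, 0), ncol = 2, byrow = TRUE)
amass <- c(1, 2)
bmass <- c(1, 1)
plan <- data.frame(from = 1, to = 1, mass = 1)

ut_objective(plan, xa, xb, amass, bmass, p = 1, C = 0.5)$value
# transport 1 * 0.5 + disposal 0.5 * (2 + 1) = 2
```

Passing an empty plan gives the cost of disposing of all mass, \(C^p (\sum_x a_x + \sum_y b_y)\).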
The minimal value of the functional above raised to the power \(1/p\) is what we refer to as the unbalanced \((p,C)\)-Wasserstein metric. This metric is used, in various variants, in a number of research papers. See Heinemann et al. (2022) and the references therein and Müller et al. (2020), Remark 3. We follow the convention of the latter paper regarding the parametrization and the use of the term unbalanced Wasserstein metric.
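One consequence of the functional, which may help when choosing \(C\), follows from a short exchange argument (our own derivation from the formula above, not taken from the cited papers). A unit of mass of \(a\) at \(x\) and a unit of mass of \(b\) at \(y\) can either be matched, at cost \(d(x,y)^p\), or both be disposed of, at cost \(2C^p\). Matching is therefore only worthwhile if

$$d(x,y)^p < 2C^p \iff d(x,y) < 2^{1/p}\, C,$$

so an optimal plan never transports mass over a distance larger than \(2^{1/p} C\) (otherwise replacing that transport by disposal would lower the cost). Since \(\pi \equiv 0\) is always feasible, the unbalanced \((p,C)\)-Wasserstein distance is also bounded by

$$\Bigl( C^p \bigl( \sum_{x \in S} a_x + \sum_{y \in S} b_y \bigr) \Bigr)^{1/p} = C \bigl( \sum_{x \in S} a_x + \sum_{y \in S} b_y \bigr)^{1/p}.$$

For example, with \(p = 1\) and \(C = 0.5\) (as for res1 in the examples), mass is transported over distances of at most 1.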
The practical difference between the two methods "networkflow" and "revsimplex" can be roughly described as follows. The former is typically faster for large examples (for pgrid objects 64x64 and beyond), especially if several threads are used. The latter is typically faster for smaller examples (which may be relevant if pairwise transports between many objects are computed), and it guarantees a sparse(r) solution, i.e. at most \(m+n+1\) individual transports, where \(m\) and \(n\) are the numbers of non-zero masses in a and b, respectively.
Note, however, that due to the implementation the revsimplex algorithm is a little less precise (roughly within 1e-7 tolerance). For more details on the algorithms see transport.
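The trade-off between the two methods can be inspected on a small example. A minimal sketch, assuming the transport package is installed; the point patterns are the ones from the wpp example, and the tolerance 1e-6 is our own loose choice based on the precision quoted above. With 4 and 5 points carrying non-zero mass, the revsimplex plan should have at most 4 + 5 + 1 = 10 rows.

```r
library(transport)

set.seed(31)
a <- wpp(matrix(runif(8), 4, 2), 3:6)   # 4 points, all masses non-zero
b <- wpp(matrix(runif(10), 5, 2), 1:5)  # 5 points, all masses non-zero

# Same problem, solved with both algorithms (output = "dist" is the default)
d_nf <- unbalanced(a, b, p = 1, C = 0.5, method = "networkflow")
d_rs <- unbalanced(a, b, p = 1, C = 0.5, method = "revsimplex")

# The distances should agree up to the small imprecision of revsimplex
abs(d_nf - d_rs)

# revsimplex guarantees at most m + n + 1 individual transports
plan_rs <- unbalanced(a, b, p = 1, C = 0.5, method = "revsimplex", output = "all")$plan
nrow(plan_rs) <= 4 + 5 + 1
```

For such small inputs both methods return almost instantly; the speed difference described above only becomes noticeable for large grids or many pairwise computations.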
Florian Heinemann, Marcel Klatt and Axel Munk (2022).
Kantorovich-Rubinstein distance and barycenter for finitely supported measures: Foundations and Algorithms.
arXiv preprint.
doi:10.48550/arXiv.2112.03581
Raoul Müller, Dominic Schuhmacher and Jorge Mateu (2020).
Metrics and barycenters for point pattern data.
Statistics and Computing 30, 953-972.
doi:10.1007/s11222-020-09932-y
plot.ut_pgrid and plot.ut_wpp, which can plot the various components of the list obtained for output = "all".
library(transport)
# two 3 x 4 images; unbalanced transport with p = 1 and C = 0.5 resp. C = 0.3
a <- pgrid(matrix(1:12, 3, 4))
b <- pgrid(matrix(c(9:4, 12:7), 3, 4))
res1 <- unbalanced(a, b, 1, 0.5, output="all")
res2 <- unbalanced(a, b, 1, 0.3, output="all")
plot(a, b, res1$plan, angle=20, rot=TRUE)
plot(a, b, res2$plan, angle=20, rot=TRUE)
# mass of a and b that is disposed of for C = 0.3
par(mfrow=c(1,2))
matimage(res2$aextra, x = a$generator[[1]], y = a$generator[[2]])
matimage(res2$bextra, x = b$generator[[1]], y = b$generator[[2]])
# weighted point patterns with 4 and 5 points
set.seed(31)
a <- wpp(matrix(runif(8),4,2), 3:6)
b <- wpp(matrix(runif(10),5,2), 1:5)
res1 <- unbalanced(a, b, 1, 0.5, output="all")
res2 <- unbalanced(a, b, 1, 0.3, output="all")
plot(a, b, res1$plan)
plot(a, b, res2$plan)