Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[MRG] Implement Sinkhorn in log-domain for WDA
* for small values of the regularization parameter (reg) the current implementation runs into numerical issues (nans and infs)

* this can be resolved by using log-domain implementation of the sinkhorn algorithm
  • Loading branch information
Jakub Zadrożny committed Jan 17, 2022
commit 65ba51ae79c89c5fb0e1f2163a6999d3cfcd8d62
23 changes: 16 additions & 7 deletions ot/dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ def dist(x1, x2):
return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)


def logsumexp(M, axis):
r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)
"""
amax = np.amax(M, axis=axis, keepdims=True)
return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
Comment thread
rflamary marked this conversation as resolved.


def sinkhorn(w1, w2, M, reg, k):
r"""Sinkhorn algorithm with fixed number of iteration (autograd)
r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
Mr = -M / reg
ui = np.zeros((M.shape[0],))
vi = np.zeros((M.shape[1],))
log_w1 = np.log(w1)
log_w2 = np.log(w2)
for i in range(k):
vi = w2 / (np.dot(K.T, ui))
ui = w1 / (np.dot(K, vi))
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
G = np.exp(ui[:, None] + Mr + vi[None, :])
return G


Expand Down
22 changes: 22 additions & 0 deletions test/test_dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ def test_wda():
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))


@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_low_reg():

n_samples = 100 # nb samples in source and target datasets
np.random.seed(0)

# generate gaussian dataset
xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)

n_features_noise = 8

xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))

p = 2

Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10)

projwda(xs)

np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))


@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():

Expand Down