Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1edec77

Browse files
Implementation of proper Sinkhorn-Knopp factorization
1 parent 4563b3c commit 1edec77

File tree

2 files changed

+102
-6
lines changed

2 files changed

+102
-6
lines changed

‎LatentGraphLearning.ipynb‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
"name": "python",
148148
"nbconvert_exporter": "python",
149149
"pygments_lexer": "ipython3",
150-
"version": "3.11.5"
150+
"version": "3.12.1"
151151
}
152152
},
153153
"nbformat": 4,

‎OptimalTransportWasserteinDistance.ipynb‎

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@
370370
"source": [
371371
"## Probabilistic formulation, doubly stochastic matrices and Sinkhorn's theorem\n",
372372
"\n",
373-
"This part is largly inspired by the following wikipedia article on [sinkhorn theorem](https://en.wikipedia.org/wiki/Sinkhorn%27s_theorem) and [doubly stochastic matrices](https://en.wikipedia.org/wiki/Doubly_stochastic_matrix) as well as this [blog post](https://zipjiang.github.io/2020/11/23/sinkhorn's-theorem-,-sinkhorn-algorithm-and-applications.html)\n",
373+
"This part is largly inspired by the following wikipedia article on [sinkhorn theorem](https://en.wikipedia.org/wiki/Sinkhorn%27s_theorem) and [doubly stochastic matrices](https://en.wikipedia.org/wiki/Doubly_stochastic_matrix) as well as this [blog post](https://zipjiang.github.io/2020/11/23/sinkhorn's-theorem-,-sinkhorn-algorithm-and-applications.html) from Zipjiang, and this [original overview](https://djalil.chafai.net/blog/2021/08/28/sinkhorn-and-circular-law/) by Djalil Chafai\n",
374374
"\n",
375375
"### (Doubly) Stochastic matrices\n",
376376
"A doubly stochastic matrix (also called bistochastic matrix) is a square matrix $X = (x_{ij})$ of nonnegative real numbers, each of whose rows and columns sums to 1, i.e: $ \\sum_{i}x_{ij}=\\sum_{j}x_{ij}=1 $\n",
@@ -390,11 +390,11 @@
390390
"metadata": {},
391391
"outputs": [],
392392
"source": [
393-
"# Sinkhorn method\n",
394-
"\n",
393+
"# Basic sinkhorn method to make a matrix a doubly-stochastic \n",
395394
"def sinkhorn(A, L):\n",
396395
" # Code for calculating the doubly stochastic matrix\n",
397396
" # using Sinkhorn-Knopp algorithm.\n",
397+
" # TODO: \n",
398398
" # ----------\n",
399399
" # Input: positive matrix A[N x N], max iteration L\n",
400400
"\n",
@@ -406,7 +406,7 @@
406406
" col_residual = np.dot(A.T, np.ones(N))-1\n",
407407
"\n",
408408
" # Test for convergence and early stop.\n",
409-
" if np.allclose(col_residual, 0) and np.allclose(row_residual, 0)::\n",
409+
" if np.allclose(col_residual, 0) and np.allclose(row_residual, 0):\n",
410410
" break\n",
411411
"\n",
412412
" return A\n",
@@ -415,6 +415,102 @@
415415
"A = np.random.uniform(0,100,(N,N))"
416416
]
417417
},
418+
{
419+
"cell_type": "code",
420+
"execution_count": 33,
421+
"metadata": {},
422+
"outputs": [
423+
{
424+
"name": "stdout",
425+
"output_type": "stream",
426+
"text": [
427+
"d1:\n",
428+
"[[ 3.99603367 0. 0. ]\n",
429+
" [ 0. 6.98839773 0. ]\n",
430+
" [ 0. 0. 10.02649081]]\n",
431+
"\n",
432+
"S:\n",
433+
"[[0.36826188 0.31559231 0.31614581]\n",
434+
" [0.30884438 0.37479894 0.31635668]\n",
435+
" [0.32289374 0.30960875 0.36749752]]\n",
436+
"\n",
437+
"d2:\n",
438+
"[[1.01930782 0. 0. ]\n",
439+
" [0. 1.03083177 0. ]\n",
440+
" [0. 0. 0.94987111]]\n",
441+
"Sum of rows: [1. 1. 1.]\n",
442+
"Sum of columns: [1. 1. 1.]\n",
443+
"Original matrix: [[1.5 1.3 1.2]\n",
444+
" [2.2 2.7 2.1]\n",
445+
" [3.3 3.2 3.5]]\n",
446+
"Rebuilt matrix: [[1.5 1.3 1.2]\n",
447+
" [2.2 2.7 2.1]\n",
448+
" [3.3 3.2 3.5]]\n"
449+
]
450+
}
451+
],
452+
"source": [
453+
"import numpy as np\n",
454+
"\n",
455+
"def sinkhorn_knopp(matrix, epsilon=1e-10, max_iterations=1000):\n",
456+
" \"\"\"\n",
457+
" Perform matrix factorization using Sinkhorn-Knopp algorithm.\n",
458+
" \n",
459+
" Parameters:\n",
460+
" matrix (numpy.ndarray): Input non-negative matrix.\n",
461+
" epsilon (float): Convergence threshold.\n",
462+
" max_iterations (int): Maximum number of iterations.\n",
463+
" \n",
464+
" Returns:\n",
465+
" d1 (numpy.ndarray): Diagonal matrix 1.\n",
466+
" S (numpy.ndarray): Doubly stochastic matrix.\n",
467+
" d2 (numpy.ndarray): Diagonal matrix 2.\n",
468+
" \"\"\"\n",
469+
" m, n = matrix.shape\n",
470+
" d1 = np.ones(n)\n",
471+
" d2 = np.ones(n)\n",
472+
" S = matrix.copy()\n",
473+
" iter, error = 0, np.inf\n",
474+
" \n",
475+
" # Sinkhorn-Knopp algorithm\n",
476+
" while error > epsilon and iter < max_iterations:\n",
477+
" # Update rows sums (column vector of rows sums)\n",
478+
" RS = np.sum(S, axis=1)\n",
479+
" d1 *= RS\n",
480+
" S = np.dot(np.diag(1/RS), S)\n",
481+
"\n",
482+
" # Update columns sums (row vector of columns sums)\n",
483+
" CS = np.sum(S, axis=0)\n",
484+
" d2 *= CS\n",
485+
" S = np.dot(S, np.diag(1/CS))\n",
486+
"\n",
487+
" iter += 1\n",
488+
"\n",
489+
" # Check convergence\n",
490+
" error = np.linalg.norm(RS - 1) + np.linalg.norm(CS - 1)\n",
491+
"\n",
492+
" return np.diag(d1), S, np.diag(d2)\n",
493+
"\n",
494+
"# Example usage with a non-negative square matrix\n",
495+
"matrix = np.array([[1.5, 1.3, 1.2],\n",
496+
" [2.2, 2.7, 2.1],\n",
497+
" [3.3, 3.2, 3.5]])\n",
498+
"\n",
499+
"d1, S, d2 = sinkhorn_knopp(matrix)\n",
500+
"\n",
501+
"print(\"d1:\")\n",
502+
"print(d1)\n",
503+
"print(\"\\nS:\")\n",
504+
"print(S)\n",
505+
"print(\"\\nd2:\")\n",
506+
"print(d2)\n",
507+
"\n",
508+
"print(f\"Sum of rows: {np.sum(S, axis=1)}\")\n",
509+
"print(f\"Sum of columns: {np.sum(S, axis=0)}\")\n",
510+
"print(f\"Original matrix: {matrix}\")\n",
511+
"print(f\"Rebuilt matrix: {np.dot(d1, np.dot(S, d2))}\")"
512+
]
513+
},
418514
{
419515
"cell_type": "code",
420516
"execution_count": 7,
@@ -1083,7 +1179,7 @@
10831179
"name": "python",
10841180
"nbconvert_exporter": "python",
10851181
"pygments_lexer": "ipython3",
1086-
"version": "3.11.5"
1182+
"version": "3.12.1"
10871183
}
10881184
},
10891185
"nbformat": 4,

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /