|
370 | 370 | "source": [
|
371 | 371 | "## Probabilistic formulation, doubly stochastic matrices and Sinkhorn's theorem\n",
|
372 | 372 | "\n",
|
373 | | - "This part is largly inspired by the following wikipedia article on [sinkhorn theorem](https://en.wikipedia.org/wiki/Sinkhorn%27s_theorem) and [doubly stochastic matrices](https://en.wikipedia.org/wiki/Doubly_stochastic_matrix) as well as this [blog post](https://zipjiang.github.io/2020/11/23/sinkhorn's-theorem-,-sinkhorn-algorithm-and-applications.html)\n", |
| 373 | + "This part is largely inspired by the Wikipedia articles on [Sinkhorn's theorem](https://en.wikipedia.org/wiki/Sinkhorn%27s_theorem) and [doubly stochastic matrices](https://en.wikipedia.org/wiki/Doubly_stochastic_matrix), as well as this [blog post](https://zipjiang.github.io/2020/11/23/sinkhorn's-theorem-,-sinkhorn-algorithm-and-applications.html) by Zipjiang and this [original overview](https://djalil.chafai.net/blog/2021/08/28/sinkhorn-and-circular-law/) by Djalil Chafai.\n", |
374 | 374 | "\n",
|
375 | 375 | "### (Doubly) Stochastic matrices\n",
|
376 | 376 | "A doubly stochastic matrix (also called a bistochastic matrix) is a square matrix $X = (x_{ij})$ of nonnegative real numbers, each of whose rows and columns sums to 1, i.e., $\\sum_{i}x_{ij}=\\sum_{j}x_{ij}=1$\n",
|
|
390 | 390 | "metadata": {},
|
391 | 391 | "outputs": [],
|
392 | 392 | "source": [
|
393 | | - "# Sinkhorn method\n", |
394 | | - "\n", |
| 393 | + "# Basic Sinkhorn method to make a matrix doubly stochastic\n", |
395 | 394 | "def sinkhorn(A, L):\n",
|
396 | 395 | " # Code for calculating the doubly stochastic matrix\n",
|
397 | 396 | " # using Sinkhorn-Knopp algorithm.\n",
|
| 397 | + " # TODO: \n", |
398 | 398 | " # ----------\n",
|
399 | 399 | " # Input: positive matrix A[N x N], max iteration L\n",
|
400 | 400 | "\n",
|
|
406 | 406 | " col_residual = np.dot(A.T, np.ones(N))-1\n",
|
407 | 407 | "\n",
|
408 | 408 | " # Test for convergence and early stop.\n",
|
409 | | - " if np.allclose(col_residual, 0) and np.allclose(row_residual, 0)::\n", |
| 409 | + " if np.allclose(col_residual, 0) and np.allclose(row_residual, 0):\n", |
410 | 410 | " break\n",
|
411 | 411 | "\n",
|
412 | 412 | " return A\n",
|
|
415 | 415 | "A = np.random.uniform(0,100,(N,N))"
|
416 | 416 | ]
|
417 | 417 | },
|
| 418 | + { |
| 419 | + "cell_type": "code", |
| 420 | + "execution_count": 33, |
| 421 | + "metadata": {}, |
| 422 | + "outputs": [ |
| 423 | + { |
| 424 | + "name": "stdout", |
| 425 | + "output_type": "stream", |
| 426 | + "text": [ |
| 427 | + "d1:\n", |
| 428 | + "[[ 3.99603367 0. 0. ]\n", |
| 429 | + " [ 0. 6.98839773 0. ]\n", |
| 430 | + " [ 0. 0. 10.02649081]]\n", |
| 431 | + "\n", |
| 432 | + "S:\n", |
| 433 | + "[[0.36826188 0.31559231 0.31614581]\n", |
| 434 | + " [0.30884438 0.37479894 0.31635668]\n", |
| 435 | + " [0.32289374 0.30960875 0.36749752]]\n", |
| 436 | + "\n", |
| 437 | + "d2:\n", |
| 438 | + "[[1.01930782 0. 0. ]\n", |
| 439 | + " [0. 1.03083177 0. ]\n", |
| 440 | + " [0. 0. 0.94987111]]\n", |
| 441 | + "Sum of rows: [1. 1. 1.]\n", |
| 442 | + "Sum of columns: [1. 1. 1.]\n", |
| 443 | + "Original matrix: [[1.5 1.3 1.2]\n", |
| 444 | + " [2.2 2.7 2.1]\n", |
| 445 | + " [3.3 3.2 3.5]]\n", |
| 446 | + "Rebuilt matrix: [[1.5 1.3 1.2]\n", |
| 447 | + " [2.2 2.7 2.1]\n", |
| 448 | + " [3.3 3.2 3.5]]\n" |
| 449 | + ] |
| 450 | + } |
| 451 | + ], |
| 452 | + "source": [ |
| 453 | + "import numpy as np\n", |
| 454 | + "\n", |
| 455 | + "def sinkhorn_knopp(matrix, epsilon=1e-10, max_iterations=1000):\n", |
| 456 | + " \"\"\"\n", |
| 457 | + " Factor a matrix as d1 @ S @ d2, with S doubly stochastic, using the Sinkhorn-Knopp algorithm.\n", |
| 458 | + " \n", |
| 459 | + " Parameters:\n", |
| 460 | + " matrix (numpy.ndarray): Input non-negative square matrix.\n", |
| 461 | + " epsilon (float): Convergence threshold.\n", |
| 462 | + " max_iterations (int): Maximum number of iterations.\n", |
| 463 | + " \n", |
| 464 | + " Returns:\n", |
| 465 | + " d1 (numpy.ndarray): Diagonal matrix 1.\n", |
| 466 | + " S (numpy.ndarray): Doubly stochastic matrix.\n", |
| 467 | + " d2 (numpy.ndarray): Diagonal matrix 2.\n", |
| 468 | + " \"\"\"\n", |
| 469 | + " m, n = matrix.shape\n", |
| 470 | + " d1 = np.ones(n)\n", |
| 471 | + " d2 = np.ones(n)\n", |
| 472 | + " S = matrix.copy()\n", |
| 473 | + " iter, error = 0, np.inf\n", |
| 474 | + " \n", |
| 475 | + " # Sinkhorn-Knopp algorithm\n", |
| 476 | + " while error > epsilon and iter < max_iterations:\n", |
| 477 | + " # Update row sums (column vector of row sums)\n", |
| 478 | + " RS = np.sum(S, axis=1)\n", |
| 479 | + " d1 *= RS\n", |
| 480 | + " S = np.dot(np.diag(1/RS), S)\n", |
| 481 | + "\n", |
| 482 | + " # Update column sums (row vector of column sums)\n", |
| 483 | + " CS = np.sum(S, axis=0)\n", |
| 484 | + " d2 *= CS\n", |
| 485 | + " S = np.dot(S, np.diag(1/CS))\n", |
| 486 | + "\n", |
| 487 | + " iter += 1\n", |
| 488 | + "\n", |
| 489 | + " # Check convergence\n", |
| 490 | + " error = np.linalg.norm(RS - 1) + np.linalg.norm(CS - 1)\n", |
| 491 | + "\n", |
| 492 | + " return np.diag(d1), S, np.diag(d2)\n", |
| 493 | + "\n", |
| 494 | + "# Example usage with a non-negative square matrix\n", |
| 495 | + "matrix = np.array([[1.5, 1.3, 1.2],\n", |
| 496 | + " [2.2, 2.7, 2.1],\n", |
| 497 | + " [3.3, 3.2, 3.5]])\n", |
| 498 | + "\n", |
| 499 | + "d1, S, d2 = sinkhorn_knopp(matrix)\n", |
| 500 | + "\n", |
| 501 | + "print(\"d1:\")\n", |
| 502 | + "print(d1)\n", |
| 503 | + "print(\"\\nS:\")\n", |
| 504 | + "print(S)\n", |
| 505 | + "print(\"\\nd2:\")\n", |
| 506 | + "print(d2)\n", |
| 507 | + "\n", |
| 508 | + "print(f\"Sum of rows: {np.sum(S, axis=1)}\")\n", |
| 509 | + "print(f\"Sum of columns: {np.sum(S, axis=0)}\")\n", |
| 510 | + "print(f\"Original matrix: {matrix}\")\n", |
| 511 | + "print(f\"Rebuilt matrix: {np.dot(d1, np.dot(S, d2))}\")" |
| 512 | + ] |
| 513 | + }, |
418 | 514 | {
|
419 | 515 | "cell_type": "code",
|
420 | 516 | "execution_count": 7,
|
|
1083 | 1179 | "name": "python",
|
1084 | 1180 | "nbconvert_exporter": "python",
|
1085 | 1181 | "pygments_lexer": "ipython3",
|
1086 | | - "version": "3.11.5" |
| 1182 | + "version": "3.12.1" |
1087 | 1183 | }
|
1088 | 1184 | },
|
1089 | 1185 | "nbformat": 4,
|
|
0 commit comments