-
Notifications
You must be signed in to change notification settings - Fork 88
Add pinv function (Moore-Penrose Pseudo-inverse) #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Draft
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
ndarray-linalg/src/pinv.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
//! Moore-Penrose pseudo-inverse of a Matrices | ||
//! | ||
//! [](https://hadrienj.github.io/posts/Deep-Learning-Book-Series-2.9-The-Moore-Penrose-Pseudoinverse/) | ||
|
||
use crate::{error::*, svd::SVDInplace, types::*}; | ||
use ndarray::*; | ||
use num_traits::Float; | ||
|
||
/// pseudo-inverse of a matrix reference | ||
pub trait Pinv { | ||
type E; | ||
type C; | ||
fn pinv(&self, threshold: Option<Self::E>) -> Result<Self::C>; | ||
} | ||
|
||
/// pseudo-inverse | ||
pub trait PInvInto { | ||
type E; | ||
type C; | ||
fn pinv_into(self, rcond: Option<Self::E>) -> Result<Self::C>; | ||
} | ||
|
||
/// pseudo-inverse for a mutable reference of a matrix | ||
pub trait PInvInplace { | ||
type E; | ||
type C; | ||
fn pinv_inplace(&mut self, rcond: Option<Self::E>) -> Result<Self::C>; | ||
} | ||
|
||
impl<A, S> PInvInto for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar + Lapack, | ||
S: DataMut<Elem = A>, | ||
{ | ||
type E = A::Real; | ||
type C = Array2<A>; | ||
|
||
fn pinv_into(mut self, rcond: Option<Self::E>) -> Result<Self::C> { | ||
self.pinv_inplace(rcond) | ||
} | ||
} | ||
|
||
impl<A, S> Pinv for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar + Lapack, | ||
S: Data<Elem = A>, | ||
{ | ||
type E = A::Real; | ||
type C = Array2<A>; | ||
|
||
fn pinv(&self, rcond: Option<Self::E>) -> Result<Self::C> { | ||
let a = self.to_owned(); | ||
a.pinv_into(rcond) | ||
} | ||
} | ||
|
||
impl<A, S> PInvInplace for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar + Lapack, | ||
S: DataMut<Elem = A>, | ||
{ | ||
type E = A::Real; | ||
type C = Array2<A>; | ||
|
||
fn pinv_inplace(&mut self, rcond: Option<Self::E>) -> Result<Self::C> { | ||
if let (Some(u), s, Some(v_h)) = self.svd_inplace(true, true)? { | ||
// threshold = ε⋅max(m, n)⋅max(Σ) | ||
// NumPy defaults rcond to 1e-15 which is about 10 * f64 machine epsilon | ||
let rcond = rcond.unwrap_or_else(|| { | ||
let (n, m) = self.dim(); | ||
Self::E::epsilon() * Self::E::real(n.max(m)) | ||
}); | ||
let threshold = rcond * s[0]; | ||
|
||
// Determine how many singular values to keep and compute the | ||
// values of `V Σ+` (up to `num_keep` columns). | ||
let (num_keep, v_s_inv) = { | ||
let mut v_h_t = v_h.reversed_axes(); | ||
let mut num_keep = 0; | ||
for (&sing_val, mut v_h_t_col) in s.iter().zip(v_h_t.columns_mut()) { | ||
if sing_val > threshold { | ||
let sing_val_recip = sing_val.recip(); | ||
v_h_t_col.map_inplace(|v_h_t| { | ||
*v_h_t = A::from_real(sing_val_recip) * v_h_t.conj() | ||
}); | ||
num_keep += 1; | ||
} else { | ||
/* | ||
if sing_val != Self::E::real(0.0) { | ||
panic!( | ||
"for {:#?} singular value {:?} smaller then threshold {:?}", | ||
&self, &sing_val, &threshold | ||
); | ||
} | ||
*/ | ||
break; | ||
} | ||
} | ||
v_h_t.slice_axis_inplace(Axis(1), Slice::from(..num_keep)); | ||
(num_keep, v_h_t) | ||
}; | ||
|
||
// Compute `U^H` (up to `num_keep` rows). | ||
let u_h = { | ||
let mut u_t = u.reversed_axes(); | ||
u_t.slice_axis_inplace(Axis(0), Slice::from(..num_keep)); | ||
u_t.map_inplace(|x| *x = x.conj()); | ||
u_t | ||
}; | ||
|
||
Ok(v_s_inv.dot(&u_h)) | ||
} else { | ||
unreachable!() | ||
} | ||
} | ||
} |
27 changes: 27 additions & 0 deletions
ndarray-linalg/src/rank.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
///! Computes the rank of a matrix using single value decomposition | ||
use ndarray::*; | ||
|
||
use super::error::*; | ||
use super::svd::SVD; | ||
use super::types::*; | ||
use num_traits::Float; | ||
|
||
pub trait Rank { | ||
fn rank(&self) -> Result<Ix>; | ||
} | ||
|
||
impl<A, S> Rank for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar + Lapack, | ||
S: Data<Elem = A>, | ||
{ | ||
fn rank(&self) -> Result<Ix> { | ||
let (_, sv, _) = self.svd(false, false)?; | ||
|
||
let (n, m) = self.dim(); | ||
let tol = A::Real::epsilon() * A::Real::real(n.max(m)) * sv[0]; | ||
|
||
let output = sv.iter().take_while(|v| v > &&tol).count(); | ||
Ok(output) | ||
} | ||
} |
117 changes: 117 additions & 0 deletions
ndarray-linalg/tests/pinv.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
use ndarray::arr2; | ||
use ndarray::*; | ||
use ndarray_linalg::*; | ||
use rand::{thread_rng, Rng}; | ||
|
||
/// create a zero rank array | ||
pub fn zero_rank<A, Sh>(sh: Sh) -> Array2<A> | ||
where | ||
A: Scalar + Lapack, | ||
Sh: ShapeBuilder<Dim = Ix2> + Clone, | ||
{ | ||
random_with_rank(sh, 0) | ||
} | ||
|
||
/// create a random matrix with a random partial rank. | ||
pub fn partial_rank<A, Sh>(sh: Sh) -> Array2<A> | ||
where | ||
A: Scalar + Lapack, | ||
Sh: ShapeBuilder<Dim = Ix2> + Clone, | ||
{ | ||
let mut rng = thread_rng(); | ||
let (m, n) = sh.clone().into_shape().raw_dim().into_pattern(); | ||
let min_dim = n.min(m); | ||
let rank = rng.gen_range(1..min_dim); | ||
println!("desired rank = {}", rank); | ||
random_with_rank(sh, rank) | ||
} | ||
|
||
/// create a random matrix and ensures it is full rank. | ||
pub fn full_rank<A, Sh>(sh: Sh) -> Array2<A> | ||
where | ||
A: Scalar + Lapack, | ||
Sh: ShapeBuilder<Dim = Ix2> + Clone, | ||
{ | ||
let (m, n) = sh.clone().into_shape().raw_dim().into_pattern(); | ||
let min_dim = n.min(m); | ||
random_with_rank(sh, min_dim) | ||
} | ||
|
||
fn test<T: Scalar + Lapack>(a: &Array2<T>, tolerance: T::Real) { | ||
println!("a = \n{:?}", &a); | ||
let a_plus: Array2<_> = a.pinv(None).unwrap(); | ||
println!("a_plus = \n{:?}", &a_plus); | ||
let ident = a.dot(&a_plus); | ||
assert_close_l2!(&ident.dot(a), &a, tolerance); | ||
assert_close_l2!(&a_plus.dot(&ident), &a_plus, tolerance); | ||
} | ||
|
||
macro_rules! test_both_impl { | ||
($type:ty, $test:tt, $n:expr, $m:expr, $t:expr) => { | ||
paste::item! { | ||
#[test] | ||
fn [<pinv_test_ $type _ $test _ $n x $m _r>]() { | ||
let a: Array2<$type> = $test(($n, $m)); | ||
test::<$type>(&a, $t); | ||
} | ||
|
||
#[test] | ||
fn [<pinv_test_ $type _ $test _ $n x $m _c>]() { | ||
let a = $test(($n, $m).f()); | ||
test::<$type>(&a, $t); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
macro_rules! test_pinv_impl { | ||
($type:ty, $n:expr, $m:expr, $a:expr) => { | ||
test_both_impl!($type, zero_rank, $n, $m, $a); | ||
test_both_impl!($type, partial_rank, $n, $m, $a); | ||
test_both_impl!($type, full_rank, $n, $m, $a); | ||
}; | ||
} | ||
|
||
test_pinv_impl!(f32, 3, 3, 1e-4); | ||
test_pinv_impl!(f32, 4, 3, 1e-4); | ||
test_pinv_impl!(f32, 3, 4, 1e-4); | ||
|
||
test_pinv_impl!(c32, 3, 3, 1e-4); | ||
test_pinv_impl!(c32, 4, 3, 1e-4); | ||
test_pinv_impl!(c32, 3, 4, 1e-4); | ||
|
||
test_pinv_impl!(f64, 3, 3, 1e-12); | ||
test_pinv_impl!(f64, 4, 3, 1e-12); | ||
test_pinv_impl!(f64, 3, 4, 1e-12); | ||
|
||
test_pinv_impl!(c64, 3, 3, 1e-12); | ||
test_pinv_impl!(c64, 4, 3, 1e-12); | ||
test_pinv_impl!(c64, 3, 4, 1e-12); | ||
|
||
// | ||
// This matrix was taken from 7.1.1 Test1 in | ||
// "On Moore-Penrose Pseudoinverse Computation for Stiffness Matrices Resulting | ||
// from Higher Order Approximation" by Marek Klimczak | ||
// https://doi.org/10.1155/2019/5060397 | ||
// | ||
#[test] | ||
fn pinv_test_single_value_less_then_threshold_3x3() { | ||
#[rustfmt::skip] | ||
let a: Array2<f64> = arr2(&[ | ||
[ 1., -1., 0.], | ||
[-1., 2., -1.], | ||
[ 0., -1., 1.] | ||
], | ||
); | ||
#[rustfmt::skip] | ||
let a_plus_actual: Array2<f64> = arr2(&[ | ||
[ 5. / 9., -1. / 9., -4. / 9.], | ||
[-1. / 9., 2. / 9., -1. / 9.], | ||
[-4. / 9., -1. / 9., 5. / 9.], | ||
], | ||
); | ||
let a_plus: Array2<_> = a.pinv(None).unwrap(); | ||
println!("a_plus -> {:?}", &a_plus); | ||
println!("a_plus_actual -> {:?}", &a_plus); | ||
assert_close_l2!(&a_plus, &a_plus_actual, 1e-15); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.