Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4913818

Browse files
Merge Tridiagonal_ into Lapack
1 parent d9c52a2 commit 4913818

File tree

2 files changed

+42
-148
lines changed

2 files changed

+42
-148
lines changed

‎lax/src/lib.rs‎

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ extern crate openblas_src as _src;
8484
#[cfg(any(feature = "netlib-system", feature = "netlib-static"))]
8585
extern crate netlib_src as _src;
8686

87-
pub mod error;
88-
pub mod flags;
89-
pub mod layout;
90-
87+
pub mod alloc;
9188
pub mod cholesky;
9289
pub mod eig;
9390
pub mod eigh;
9491
pub mod eigh_generalized;
92+
pub mod error;
93+
pub mod flags;
94+
pub mod layout;
9595
pub mod least_squares;
9696
pub mod opnorm;
9797
pub mod qr;
@@ -101,16 +101,12 @@ pub mod solveh;
101101
pub mod svd;
102102
pub mod svddc;
103103
pub mod triangular;
104+
pub mod tridiagonal;
104105

105-
mod alloc;
106-
mod tridiagonal;
107-
108-
pub use self::cholesky::*;
109106
pub use self::flags::*;
110107
pub use self::least_squares::LeastSquaresOwned;
111-
pub use self::opnorm::*;
112108
pub use self::svd::{SvdOwned, SvdRef};
113-
pub use self::tridiagonal::*;
109+
pub use self::tridiagonal::{LUFactorizedTridiagonal,Tridiagonal};
114110

115111
use self::{alloc::*, error::*, layout::*};
116112
use cauchy::*;
@@ -120,7 +116,7 @@ pub type Pivot = Vec<i32>;
120116

121117
#[cfg_attr(doc, katexit::katexit)]
122118
/// Trait for primitive types which implements LAPACK subroutines
123-
pub trait Lapack: Tridiagonal_ {
119+
pub trait Lapack: Scalar {
124120
/// Compute right eigenvalue and eigenvectors for a general matrix
125121
fn eig(
126122
calc_v: bool,
@@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ {
306302
a: &[Self],
307303
b: &mut [Self],
308304
) -> Result<()>;
305+
306+
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
307+
/// partial pivoting with row interchanges.
308+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
309+
310+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
311+
312+
fn solve_tridiagonal(
313+
lu: &LUFactorizedTridiagonal<Self>,
314+
bl: MatrixLayout,
315+
t: Transpose,
316+
b: &mut [Self],
317+
) -> Result<()>;
309318
}
310319

311320
macro_rules! impl_lapack {
@@ -491,6 +500,28 @@ macro_rules! impl_lapack {
491500
use triangular::*;
492501
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
493502
}
503+
504+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
505+
use tridiagonal::*;
506+
let work = LuTridiagonalWork::<$s>::new(a.l);
507+
work.eval(a)
508+
}
509+
510+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
511+
use tridiagonal::*;
512+
let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l);
513+
work.calc(lu)
514+
}
515+
516+
fn solve_tridiagonal(
517+
lu: &LUFactorizedTridiagonal<Self>,
518+
bl: MatrixLayout,
519+
t: Transpose,
520+
b: &mut [Self],
521+
) -> Result<()> {
522+
use tridiagonal::*;
523+
SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b)
524+
}
494525
}
495526
};
496527
}

‎lax/src/tridiagonal/mod.rs‎

Lines changed: 0 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -10,140 +10,3 @@ pub use lu::*;
1010
pub use matrix::*;
1111
pub use rcond::*;
1212
pub use solve::*;
13-
14-
use crate::{error::*, layout::*, *};
15-
use cauchy::*;
16-
use num_traits::Zero;
17-
18-
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
19-
pub trait Tridiagonal_: Scalar + Sized {
20-
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
21-
/// partial pivoting with row interchanges.
22-
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
23-
24-
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
25-
26-
fn solve_tridiagonal(
27-
lu: &LUFactorizedTridiagonal<Self>,
28-
bl: MatrixLayout,
29-
t: Transpose,
30-
b: &mut [Self],
31-
) -> Result<()>;
32-
}
33-
34-
macro_rules! impl_tridiagonal {
35-
(@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
36-
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
37-
};
38-
(@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
39-
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
40-
};
41-
(@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
42-
impl Tridiagonal_ for $scalar {
43-
fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
44-
let (n, _) = a.l.size();
45-
let mut du2 = vec_uninit( (n - 2) as usize);
46-
let mut ipiv = vec_uninit( n as usize);
47-
// We have to calc one-norm before LU factorization
48-
let a_opnorm_one = a.opnorm_one();
49-
let mut info = 0;
50-
unsafe {
51-
$gttrf(
52-
&n,
53-
AsPtr::as_mut_ptr(&mut a.dl),
54-
AsPtr::as_mut_ptr(&mut a.d),
55-
AsPtr::as_mut_ptr(&mut a.du),
56-
AsPtr::as_mut_ptr(&mut du2),
57-
AsPtr::as_mut_ptr(&mut ipiv),
58-
&mut info,
59-
)
60-
};
61-
info.as_lapack_result()?;
62-
let du2 = unsafe { du2.assume_init() };
63-
let ipiv = unsafe { ipiv.assume_init() };
64-
Ok(LUFactorizedTridiagonal {
65-
a,
66-
du2,
67-
ipiv,
68-
a_opnorm_one,
69-
})
70-
}
71-
72-
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
73-
let (n, _) = lu.a.l.size();
74-
let ipiv = &lu.ipiv;
75-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(2 * n as usize);
76-
$(
77-
let mut $iwork: Vec<MaybeUninit<i32>> = vec_uninit(n as usize);
78-
)*
79-
let mut rcond = Self::Real::zero();
80-
let mut info = 0;
81-
unsafe {
82-
$gtcon(
83-
NormType::One.as_ptr(),
84-
&n,
85-
AsPtr::as_ptr(&lu.a.dl),
86-
AsPtr::as_ptr(&lu.a.d),
87-
AsPtr::as_ptr(&lu.a.du),
88-
AsPtr::as_ptr(&lu.du2),
89-
ipiv.as_ptr(),
90-
&lu.a_opnorm_one,
91-
&mut rcond,
92-
AsPtr::as_mut_ptr(&mut work),
93-
$(AsPtr::as_mut_ptr(&mut $iwork),)*
94-
&mut info,
95-
);
96-
}
97-
info.as_lapack_result()?;
98-
Ok(rcond)
99-
}
100-
101-
fn solve_tridiagonal(
102-
lu: &LUFactorizedTridiagonal<Self>,
103-
b_layout: MatrixLayout,
104-
t: Transpose,
105-
b: &mut [Self],
106-
) -> Result<()> {
107-
let (n, _) = lu.a.l.size();
108-
let ipiv = &lu.ipiv;
109-
// Transpose if b is C-continuous
110-
let mut b_t = None;
111-
let b_layout = match b_layout {
112-
MatrixLayout::C { .. } => {
113-
let (layout, t) = transpose(b_layout, b);
114-
b_t = Some(t);
115-
layout
116-
}
117-
MatrixLayout::F { .. } => b_layout,
118-
};
119-
let (ldb, nrhs) = b_layout.size();
120-
let mut info = 0;
121-
unsafe {
122-
$gttrs(
123-
t.as_ptr(),
124-
&n,
125-
&nrhs,
126-
AsPtr::as_ptr(&lu.a.dl),
127-
AsPtr::as_ptr(&lu.a.d),
128-
AsPtr::as_ptr(&lu.a.du),
129-
AsPtr::as_ptr(&lu.du2),
130-
ipiv.as_ptr(),
131-
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
132-
&ldb,
133-
&mut info,
134-
);
135-
}
136-
info.as_lapack_result()?;
137-
if let Some(b_t) = b_t {
138-
transpose_over(b_layout, &b_t, b);
139-
}
140-
Ok(())
141-
}
142-
}
143-
};
144-
} // impl_tridiagonal!
145-
146-
impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_);
147-
impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_);
148-
impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_);
149-
impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_);

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /