Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 17f9fb8

Browse files
Split definition of tridiagonal matrix
1 parent ad19250 commit 17f9fb8

File tree

2 files changed

+110
-103
lines changed

2 files changed

+110
-103
lines changed

‎lax/src/tridiagonal/matrix.rs‎

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use crate::layout::*;
2+
use cauchy::*;
3+
use num_traits::Zero;
4+
use std::ops::{Index, IndexMut};
5+
6+
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
7+
///
8+
/// ```text
9+
/// [d0, u1, 0, ..., 0,
10+
/// l1, d1, u2, ...,
11+
/// 0, l2, d2,
12+
/// ... ..., u{n-1},
13+
/// 0, ..., l{n-1}, d{n-1},]
14+
/// ```
15+
#[derive(Clone, PartialEq, Eq)]
16+
pub struct Tridiagonal<A: Scalar> {
17+
/// layout of raw matrix
18+
pub l: MatrixLayout,
19+
/// (n-1) sub-diagonal elements of matrix.
20+
pub dl: Vec<A>,
21+
/// (n) diagonal elements of matrix.
22+
pub d: Vec<A>,
23+
/// (n-1) super-diagonal elements of matrix.
24+
pub du: Vec<A>,
25+
}
26+
27+
impl<A: Scalar> Tridiagonal<A> {
28+
pub fn opnorm_one(&self) -> A::Real {
29+
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
30+
for i in 0..col_sum.len() {
31+
if i < self.dl.len() {
32+
col_sum[i] += self.dl[i].abs();
33+
}
34+
if i > 0 {
35+
col_sum[i] += self.du[i - 1].abs();
36+
}
37+
}
38+
let mut max = A::Real::zero();
39+
for &val in &col_sum {
40+
if max < val {
41+
max = val;
42+
}
43+
}
44+
max
45+
}
46+
}
47+
48+
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
49+
type Output = A;
50+
#[inline]
51+
fn index(&self, (row, col): (i32, i32)) -> &A {
52+
let (n, _) = self.l.size();
53+
assert!(
54+
std::cmp::max(row, col) < n,
55+
"ndarray: index {:?} is out of bounds for array of shape {}",
56+
[row, col],
57+
n
58+
);
59+
match row - col {
60+
0 => &self.d[row as usize],
61+
1 => &self.dl[col as usize],
62+
-1 => &self.du[row as usize],
63+
_ => panic!(
64+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
65+
[row, col]
66+
),
67+
}
68+
}
69+
}
70+
71+
impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
72+
type Output = A;
73+
#[inline]
74+
fn index(&self, [row, col]: [i32; 2]) -> &A {
75+
&self[(row, col)]
76+
}
77+
}
78+
79+
impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
80+
#[inline]
81+
fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
82+
let (n, _) = self.l.size();
83+
assert!(
84+
std::cmp::max(row, col) < n,
85+
"ndarray: index {:?} is out of bounds for array of shape {}",
86+
[row, col],
87+
n
88+
);
89+
match row - col {
90+
0 => &mut self.d[row as usize],
91+
1 => &mut self.dl[col as usize],
92+
-1 => &mut self.du[row as usize],
93+
_ => panic!(
94+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
95+
[row, col]
96+
),
97+
}
98+
}
99+
}
100+
101+
impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
102+
#[inline]
103+
fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
104+
&mut self[(row, col)]
105+
}
106+
}

‎lax/src/tridiagonal.rs‎ renamed to ‎lax/src/tridiagonal/mod.rs‎

Lines changed: 4 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,13 @@
11
//! Implement linear solver using LU decomposition
22
//! for tridiagonal matrix
33
4+
mod matrix;
5+
6+
pub use matrix::*;
7+
48
use crate::{error::*, layout::*, *};
59
use cauchy::*;
610
use num_traits::Zero;
7-
use std::ops::{Index, IndexMut};
8-
9-
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
10-
///
11-
/// ```text
12-
/// [d0, u1, 0, ..., 0,
13-
/// l1, d1, u2, ...,
14-
/// 0, l2, d2,
15-
/// ... ..., u{n-1},
16-
/// 0, ..., l{n-1}, d{n-1},]
17-
/// ```
18-
#[derive(Clone, PartialEq, Eq)]
19-
pub struct Tridiagonal<A: Scalar> {
20-
/// layout of raw matrix
21-
pub l: MatrixLayout,
22-
/// (n-1) sub-diagonal elements of matrix.
23-
pub dl: Vec<A>,
24-
/// (n) diagonal elements of matrix.
25-
pub d: Vec<A>,
26-
/// (n-1) super-diagonal elements of matrix.
27-
pub du: Vec<A>,
28-
}
29-
30-
impl<A: Scalar> Tridiagonal<A> {
31-
fn opnorm_one(&self) -> A::Real {
32-
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
33-
for i in 0..col_sum.len() {
34-
if i < self.dl.len() {
35-
col_sum[i] += self.dl[i].abs();
36-
}
37-
if i > 0 {
38-
col_sum[i] += self.du[i - 1].abs();
39-
}
40-
}
41-
let mut max = A::Real::zero();
42-
for &val in &col_sum {
43-
if max < val {
44-
max = val;
45-
}
46-
}
47-
max
48-
}
49-
}
5011

5112
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
5213
#[derive(Clone, PartialEq)]
@@ -65,66 +26,6 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
6526
a_opnorm_one: A::Real,
6627
}
6728

68-
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
69-
type Output = A;
70-
#[inline]
71-
fn index(&self, (row, col): (i32, i32)) -> &A {
72-
let (n, _) = self.l.size();
73-
assert!(
74-
std::cmp::max(row, col) < n,
75-
"ndarray: index {:?} is out of bounds for array of shape {}",
76-
[row, col],
77-
n
78-
);
79-
match row - col {
80-
0 => &self.d[row as usize],
81-
1 => &self.dl[col as usize],
82-
-1 => &self.du[row as usize],
83-
_ => panic!(
84-
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
85-
[row, col]
86-
),
87-
}
88-
}
89-
}
90-
91-
impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
92-
type Output = A;
93-
#[inline]
94-
fn index(&self, [row, col]: [i32; 2]) -> &A {
95-
&self[(row, col)]
96-
}
97-
}
98-
99-
impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
100-
#[inline]
101-
fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
102-
let (n, _) = self.l.size();
103-
assert!(
104-
std::cmp::max(row, col) < n,
105-
"ndarray: index {:?} is out of bounds for array of shape {}",
106-
[row, col],
107-
n
108-
);
109-
match row - col {
110-
0 => &mut self.d[row as usize],
111-
1 => &mut self.dl[col as usize],
112-
-1 => &mut self.du[row as usize],
113-
_ => panic!(
114-
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
115-
[row, col]
116-
),
117-
}
118-
}
119-
}
120-
121-
impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
122-
#[inline]
123-
fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
124-
&mut self[(row, col)]
125-
}
126-
}
127-
12829
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
12930
pub trait Tridiagonal_: Scalar + Sized {
13031
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /