Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4111835

Browse files
Merge pull request #272 from janmarthedal/relax-lstsq-types
Relax type bounds for LeastSquaresSvd family
2 parents aee87d9 + a23224f commit 4111835

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

‎ndarray-linalg/src/least_squares.rs‎

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,13 @@ where
149149

150150
/// Solve least squares for immutable references and a single
151151
/// column vector as a right-hand side.
152-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
153-
/// valid representation for `ArrayBase`.
154-
impl<E, D> LeastSquaresSvd<D, E, Ix1> for ArrayBase<D, Ix2>
152+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
153+
/// valid representation for `ArrayBase` (over `E`).
154+
impl<E, D1,D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
155155
where
156156
E: Scalar + Lapack,
157-
D: Data<Elem = E>,
157+
D1: Data<Elem = E>,
158+
D2: Data<Elem = E>,
158159
{
159160
/// Solve a least squares problem of the form `Ax = rhs`
160161
/// by calling `A.least_squares(&rhs)`, where `rhs` is a
@@ -163,7 +164,7 @@ where
163164
/// `A` and `rhs` must have the same layout, i.e. they must
164165
/// be both either row- or column-major format, otherwise a
165166
/// `IncompatibleShape` error is raised.
166-
fn least_squares(&self, rhs: &ArrayBase<D, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
167+
fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
167168
let a = self.to_owned();
168169
let b = rhs.to_owned();
169170
a.least_squares_into(b)
@@ -172,12 +173,13 @@ where
172173

173174
/// Solve least squares for immutable references and matrix
174175
/// (=mulitipe vectors) as a right-hand side.
175-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
176-
/// valid representation for `ArrayBase`.
177-
impl<E, D> LeastSquaresSvd<D, E, Ix2> for ArrayBase<D, Ix2>
176+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
177+
/// valid representation for `ArrayBase` (over `E`).
178+
impl<E, D1,D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
178179
where
179180
E: Scalar + Lapack,
180-
D: Data<Elem = E>,
181+
D1: Data<Elem = E>,
182+
D2: Data<Elem = E>,
181183
{
182184
/// Solve a least squares problem of the form `Ax = rhs`
183185
/// by calling `A.least_squares(&rhs)`, where `rhs` is
@@ -186,7 +188,7 @@ where
186188
/// `A` and `rhs` must have the same layout, i.e. they must
187189
/// be both either row- or column-major format, otherwise a
188190
/// `IncompatibleShape` error is raised.
189-
fn least_squares(&self, rhs: &ArrayBase<D, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
191+
fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
190192
let a = self.to_owned();
191193
let b = rhs.to_owned();
192194
a.least_squares_into(b)
@@ -199,10 +201,11 @@ where
199201
///
200202
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
201203
/// valid representation for `ArrayBase`.
202-
impl<E, D> LeastSquaresSvdInto<D, E, Ix1> for ArrayBase<D, Ix2>
204+
impl<E, D1,D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
203205
where
204206
E: Scalar + Lapack,
205-
D: DataMut<Elem = E>,
207+
D1: DataMut<Elem = E>,
208+
D2: DataMut<Elem = E>,
206209
{
207210
/// Solve a least squares problem of the form `Ax = rhs`
208211
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -213,7 +216,7 @@ where
213216
/// `IncompatibleShape` error is raised.
214217
fn least_squares_into(
215218
mut self,
216-
mut rhs: ArrayBase<D, Ix1>,
219+
mut rhs: ArrayBase<D2, Ix1>,
217220
) -> Result<LeastSquaresResult<E, Ix1>> {
218221
self.least_squares_in_place(&mut rhs)
219222
}
@@ -223,12 +226,13 @@ where
223226
/// as a right-hand side. The matrix and the RHS matrix
224227
/// are consumed.
225228
///
226-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
227-
/// valid representation for `ArrayBase`.
228-
impl<E, D> LeastSquaresSvdInto<D, E, Ix2> for ArrayBase<D, Ix2>
229+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
230+
/// valid representation for `ArrayBase` (over `E`).
231+
impl<E, D1,D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
229232
where
230233
E: Scalar + Lapack,
231-
D: DataMut<Elem = E>,
234+
D1: DataMut<Elem = E>,
235+
D2: DataMut<Elem = E>,
232236
{
233237
/// Solve a least squares problem of the form `Ax = rhs`
234238
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -239,7 +243,7 @@ where
239243
/// `IncompatibleShape` error is raised.
240244
fn least_squares_into(
241245
mut self,
242-
mut rhs: ArrayBase<D, Ix2>,
246+
mut rhs: ArrayBase<D2, Ix2>,
243247
) -> Result<LeastSquaresResult<E, Ix2>> {
244248
self.least_squares_in_place(&mut rhs)
245249
}
@@ -249,12 +253,13 @@ where
249253
/// as a right-hand side. Both values are overwritten in the
250254
/// call.
251255
///
252-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
253-
/// valid representation for `ArrayBase`.
254-
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix1> for ArrayBase<D, Ix2>
256+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
257+
/// valid representation for `ArrayBase` (over `E`).
258+
impl<E, D1,D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
255259
where
256260
E: Scalar + Lapack,
257-
D: DataMut<Elem = E>,
261+
D1: DataMut<Elem = E>,
262+
D2: DataMut<Elem = E>,
258263
{
259264
/// Solve a least squares problem of the form `Ax = rhs`
260265
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -265,7 +270,7 @@ where
265270
/// `IncompatibleShape` error is raised.
266271
fn least_squares_in_place(
267272
&mut self,
268-
rhs: &mut ArrayBase<D, Ix1>,
273+
rhs: &mut ArrayBase<D2, Ix1>,
269274
) -> Result<LeastSquaresResult<E, Ix1>> {
270275
if self.shape()[0] != rhs.shape()[0] {
271276
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
@@ -331,12 +336,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
331336
/// as a right-hand side. Both values are overwritten in the
332337
/// call.
333338
///
334-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
335-
/// valid representation for `ArrayBase`.
336-
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix2> for ArrayBase<D, Ix2>
339+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
340+
/// valid representation for `ArrayBase` (over `E`).
341+
impl<E, D1,D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
337342
where
338343
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
339-
D: DataMut<Elem = E>,
344+
D1: DataMut<Elem = E>,
345+
D2: DataMut<Elem = E>,
340346
{
341347
/// Solve a least squares problem of the form `Ax = rhs`
342348
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -347,7 +353,7 @@ where
347353
/// `IncompatibleShape` error is raised.
348354
fn least_squares_in_place(
349355
&mut self,
350-
rhs: &mut ArrayBase<D, Ix2>,
356+
rhs: &mut ArrayBase<D2, Ix2>,
351357
) -> Result<LeastSquaresResult<E, Ix2>> {
352358
if self.shape()[0] != rhs.shape()[0] {
353359
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
@@ -425,7 +431,7 @@ mod tests {
425431
use ndarray::*;
426432

427433
//
428-
// Test that the different lest squares traits work as intended on the
434+
// Test that the different least squares traits work as intended on the
429435
// different array types.
430436
//
431437
// | least_squares | ls_into | ls_in_place |
@@ -437,9 +443,9 @@ mod tests {
437443
// ArrayViewMut | yes | no | yes |
438444
//
439445

440-
fn assert_result<D: Data<Elem = f64>>(
441-
a: &ArrayBase<D, Ix2>,
442-
b: &ArrayBase<D, Ix1>,
446+
fn assert_result<D1:Data<Elem = f64>,D2: Data<Elem = f64>>(
447+
a: &ArrayBase<D1, Ix2>,
448+
b: &ArrayBase<D2, Ix1>,
443449
res: &LeastSquaresResult<f64, Ix1>,
444450
) {
445451
assert_eq!(res.rank, 2);
@@ -487,6 +493,15 @@ mod tests {
487493
assert_result(&av, &bv, &res);
488494
}
489495

496+
#[test]
497+
fn on_cow_view() {
498+
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
499+
let b: Array1<f64> = array![1., 2., 3.];
500+
let bv = b.view();
501+
let res = a.least_squares(&bv).unwrap();
502+
assert_result(&a, &bv, &res);
503+
}
504+
490505
#[test]
491506
fn into_on_owned() {
492507
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
@@ -517,6 +532,16 @@ mod tests {
517532
assert_result(&a, &b, &res);
518533
}
519534

535+
#[test]
536+
fn into_on_owned_cow() {
537+
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
538+
let b = CowArray::from(array![1., 2., 3.]);
539+
let ac = a.clone();
540+
let b2 = b.clone();
541+
let res = ac.least_squares_into(b2).unwrap();
542+
assert_result(&a, &b, &res);
543+
}
544+
520545
#[test]
521546
fn in_place_on_owned() {
522547
let a = array![[1., 2.], [4., 5.], [3., 4.]];
@@ -549,6 +574,16 @@ mod tests {
549574
assert_result(&a, &b, &res);
550575
}
551576

577+
#[test]
578+
fn in_place_on_owned_cow() {
579+
let a = array![[1., 2.], [4., 5.], [3., 4.]];
580+
let b = CowArray::from(array![1., 2., 3.]);
581+
let mut a2 = a.clone();
582+
let mut b2 = b.clone();
583+
let res = a2.least_squares_in_place(&mut b2).unwrap();
584+
assert_result(&a, &b, &res);
585+
}
586+
552587
//
553588
// Testing error cases
554589
//

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /