added covariant matrix multiplication, inverse and solve - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
authorfrancesco-ST <francesco.abbate@st.com>2010年10月25日 16:10:41 +0200
committerfrancesco-ST <francesco.abbate@st.com>2010年10月25日 16:10:41 +0200
commit53b62a19f80914481b3c4ddd29a34e58fb43b543 (patch)
tree6497e00aa60de23d10f098740524095b64537dfc
parent16b347628af89740c5cd4f5c6c9fd632fea55360 (diff)
downloadgsl-shell-53b62a19f80914481b3c4ddd29a34e58fb43b543.tar.gz
added covariant matrix multiplication, inverse and solve
Diffstat
-rw-r--r--lua-gsl.c 2
-rw-r--r--matrix_arith.c 146
-rw-r--r--matrix_arith.h 2
-rw-r--r--matrix_decls_source.c 11
-rw-r--r--matrix_headers_source.h 7
-rw-r--r--matrix_source.c 40
6 files changed, 163 insertions, 45 deletions
diff --git a/lua-gsl.c b/lua-gsl.c
index 41380454..8fc08550 100644
--- a/lua-gsl.c
+++ b/lua-gsl.c
@@ -28,6 +28,7 @@
#include "cnlinfit.h"
#include "matrix.h"
#include "cmatrix.h"
+#include "matrix_arith.h"
#include "linalg.h"
#include "integ.h"
#include "fft.h"
@@ -76,6 +77,7 @@ luaopen_gsl (lua_State *L)
solver_register (L);
matrix_register (L);
+ matrix_arith_register (L);
linalg_register (L);
integ_register (L);
ode_register (L);
diff --git a/matrix_arith.c b/matrix_arith.c
index 9b6ae233..13155d68 100644
--- a/matrix_arith.c
+++ b/matrix_arith.c
@@ -23,6 +23,9 @@
#include <assert.h>
#include <string.h>
#include <gsl/gsl_matrix.h>
+#include <gsl/gsl_blas.h>
+#include <gsl/gsl_permutation.h>
+#include <gsl/gsl_linalg.h>
#include "gs-types.h"
#include "matrix.h"
@@ -30,6 +33,27 @@
#include "matrix_arith.h"
#include "lua-utils.h"
+struct pmatrix {
+ int tp;
+ union {
+ gsl_matrix *real;
+ gsl_matrix_complex *cmpl;
+ } m;
+};
+
+static int matrix_inv (lua_State *L);
+static int matrix_solve (lua_State *L);
+static int matrix_mul (lua_State *L);
+
+static const struct luaL_Reg matrix_arith_functions[] = {
+ // {"dim", matrix_dim)},
+ // {"copy", matrix_copy},
+ {"mul", matrix_mul},
+ {"solve", matrix_solve},
+ {"inv", matrix_inv},
+ {NULL, NULL}
+};
+
static const char * size_err_msg = "matrices should have the same size in %s";
static gsl_matrix_complex *
@@ -53,6 +77,33 @@ push_matrix_complex_of_real (lua_State *L, const gsl_matrix *a)
return r;
}
+static void
+check_matrix_type (lua_State *L, int index, struct pmatrix *r)
+{
+ if (gs_is_userdata (L, index, GS_MATRIX))
+ {
+ r->tp = GS_MATRIX;
+ r->m.real = lua_touserdata (L, index);
+ }
+ else if (gs_is_userdata (L, index, GS_CMATRIX))
+ {
+ r->tp = GS_CMATRIX;
+ r->m.cmpl = lua_touserdata (L, index);
+ }
+ else
+ {
+ gs_type_error (L, index, "matrix");
+ }
+}
+
+static void
+matrix_complex_promote (lua_State *L, int index, struct pmatrix *a)
+{
+ a->tp = GS_CMATRIX;
+ a->m.cmpl = push_matrix_complex_of_real (L, a->m.real);
+ lua_replace (L, index);
+}
+
#define OPER_ADD
#include "template_matrix_oper_on.h"
#include "matrix_op_source.c"
@@ -120,3 +171,98 @@ matrix_unm (lua_State *L)
return 1;
}
+
+int
+matrix_mul (lua_State *L)
+{
+ int nargs = lua_gettop (L);
+ struct pmatrix a, b, r;
+ int k;
+
+ for (k = nargs - 1; k >= 1; k--)
+ {
+ check_matrix_type (L, k, &a);
+ check_matrix_type (L, k+1, &b);
+
+ r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX);
+
+ if (a.tp != r.tp)
+ matrix_complex_promote (L, k, &a);
+
+ if (b.tp != r.tp)
+ matrix_complex_promote (L, k+1, &b);
+
+ if (r.tp == GS_MATRIX)
+ r.m.real = matrix_push (L, a.m.real->size1, b.m.real->size2);
+ else
+ r.m.cmpl = matrix_complex_push (L, a.m.cmpl->size1, b.m.cmpl->size2);
+
+ if (r.tp == GS_MATRIX)
+ {
+ gsl_blas_dgemm (CblasNoTrans, CblasNoTrans,
+ 1.0, a.m.real, b.m.real, 1.0, r.m.real);
+ }
+ else
+ {
+ gsl_complex u = {{1.0, 0.0}};
+ gsl_blas_zgemm (CblasNoTrans, CblasNoTrans,
+ u, a.m.cmpl, b.m.cmpl, u, r.m.cmpl);
+ }
+
+ lua_insert (L, k);
+ lua_pop (L, 2);
+ }
+
+ return 1;
+}
+
+int
+matrix_inv (lua_State *L)
+{
+ struct pmatrix a;
+ check_matrix_type (L, 1, &a);
+ switch (a.tp)
+ {
+ case GS_MATRIX:
+ return matrix_inverse_raw (L, a.m.real);
+ case GS_CMATRIX:
+ return matrix_complex_inverse_raw (L, a.m.cmpl);
+ default:
+ /* */;
+ }
+ return 0;
+}
+
+int
+matrix_solve (lua_State *L)
+{
+ struct pmatrix a, b, r;
+ check_matrix_type (L, 1, &a);
+ check_matrix_type (L, 2, &b);
+
+ r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX);
+
+ if (a.tp != r.tp)
+ matrix_complex_promote (L, 1, &a);
+
+ if (b.tp != r.tp)
+ matrix_complex_promote (L, 2, &b);
+
+ switch (r.tp)
+ {
+ case GS_MATRIX:
+ return matrix_solve_raw (L, a.m.real, b.m.real);
+ case GS_CMATRIX:
+ return matrix_complex_solve_raw (L, a.m.cmpl, b.m.cmpl);
+ default:
+ /* */;
+ }
+
+ return 0;
+}
+
+void
+matrix_arith_register (lua_State *L)
+{
+ luaL_register (L, NULL, matrix_arith_functions);
+}
diff --git a/matrix_arith.h b/matrix_arith.h
index a3f9beca..7116f7bb 100644
--- a/matrix_arith.h
+++ b/matrix_arith.h
@@ -10,4 +10,6 @@ extern int matrix_mul_elements (lua_State *L);
extern int matrix_div_elements (lua_State *L);
extern int matrix_unm (lua_State *L);
+extern void matrix_arith_register (lua_State *L);
+
#endif
diff --git a/matrix_decls_source.c b/matrix_decls_source.c
index a3402bcf..6afe0a48 100644
--- a/matrix_decls_source.c
+++ b/matrix_decls_source.c
@@ -28,10 +28,6 @@ static int FUNCTION (matrix, dims) (lua_State *L);
static int FUNCTION (matrix, copy) (lua_State *L);
static int FUNCTION (matrix, slice) (lua_State *L);
-static int FUNCTION (matrix, mul) (lua_State *L);
-static int FUNCTION (matrix, inverse) (lua_State *L);
-static int FUNCTION (matrix, solve) (lua_State *L);
-
static void FUNCTION (matrix, set_ref) (lua_State *L, int index);
static const struct luaL_Reg FUNCTION (matrix, methods)[] = {
@@ -41,18 +37,15 @@ static const struct luaL_Reg FUNCTION (matrix, methods)[] = {
{"__div", matrix_div_elements},
{"__unm", matrix_unm},
{"__gc", FUNCTION (matrix, free)},
- {"get", FUNCTION (matrix, get)},
- {"set", FUNCTION (matrix, set)},
{"dims", FUNCTION (matrix, dims)},
{"copy", FUNCTION (matrix, copy)},
+ {"get", FUNCTION (matrix, get)},
+ {"set", FUNCTION (matrix, set)},
{"slice", FUNCTION (matrix, slice)},
{NULL, NULL}
};
static const struct luaL_Reg FUNCTION (matrix, functions)[] = {
{PREFIX "new", FUNCTION (matrix, new)},
- {PREFIX "mul", FUNCTION (matrix, mul)},
- {PREFIX "solve", FUNCTION (matrix, solve)},
- {PREFIX "inverse", FUNCTION (matrix, inverse)},
{NULL, NULL}
};
diff --git a/matrix_headers_source.h b/matrix_headers_source.h
index e63bdab7..0515d990 100644
--- a/matrix_headers_source.h
+++ b/matrix_headers_source.h
@@ -20,6 +20,13 @@ extern void FUNCTION (matrix, check_size) (lua_State *L,
TYPE (gsl_matrix) *m,
size_t n1, size_t n2);
+extern int FUNCTION (matrix, inverse_raw)(lua_State *L,
+ const TYPE (gsl_matrix) *a);
+
+extern int FUNCTION (matrix, solve_raw) (lua_State *L,
+ const TYPE (gsl_matrix) *a,
+ const TYPE (gsl_matrix) *b);
+
/* matrix helper functions */
extern void
diff --git a/matrix_source.c b/matrix_source.c
index b55e1280..ea04763b 100644
--- a/matrix_source.c
+++ b/matrix_source.c
@@ -268,40 +268,8 @@ FUNCTION(matrix, new) (lua_State *L)
}
int
-FUNCTION(matrix, mul) (lua_State *L)
+FUNCTION(matrix, inverse_raw) (lua_State *L, const TYPE (gsl_matrix) *a)
{
- int k, nargs = lua_gettop (L);
- TYPE (gsl_matrix) *r, *a, *b;
- BASE one = ONE;
-
- r = FUNCTION (matrix, check) (L, 1);
-
- for (k = 2; k <= nargs; k++)
- {
- a = r;
- b = FUNCTION (matrix, check) (L, k);
-
- if (a->size2 != b->size1)
- luaL_error (L, "incompatible matrix dimensions in multiplication");
-
- if (k == nargs)
- r = FUNCTION (matrix, push) (L, a->size1, b->size2);
- else
- r = FUNCTION (gsl_matrix, calloc) (a->size1, b->size2);
-
- BLAS_FUNCTION(gemm) (CblasNoTrans, CblasNoTrans, one, a, b, one, r);
-
- if (k > 2)
- FUNCTION (gsl_matrix, free) (a);
- }
-
- return 1;
-}
-
-int
-FUNCTION(matrix, inverse) (lua_State *L)
-{
- const TYPE (gsl_matrix) *a = FUNCTION (matrix, check) (L, 1);
TYPE (gsl_matrix) *lu, *inverse;
gsl_permutation *p;
size_t n = a->size1;
@@ -326,10 +294,10 @@ FUNCTION(matrix, inverse) (lua_State *L)
return 1;
}
int
-FUNCTION(matrix, solve) (lua_State *L)
+FUNCTION(matrix, solve_raw) (lua_State *L,
+ const TYPE (gsl_matrix) *a,
+ const TYPE (gsl_matrix) *b)
{
- const TYPE (gsl_matrix) *a = FUNCTION (matrix, check) (L, 1);
- const TYPE (gsl_matrix) *b = FUNCTION (matrix, check) (L, 2);
TYPE (gsl_matrix) *x;
CONST_VIEW (gsl_vector) b_view = CONST_FUNCTION (gsl_matrix, column) (b, 0);
VIEW (gsl_vector) x_view;
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月26日 20:20:17 +0000

AltStyle によって変換されたページ (->オリジナル) /