author | francesco-ST <francesco.abbate@st.com> | 2010年10月25日 16:10:41 +0200 |
---|---|---|
committer | francesco-ST <francesco.abbate@st.com> | 2010年10月25日 16:10:41 +0200 |
commit | 53b62a19f80914481b3c4ddd29a34e58fb43b543 (patch) | |
tree | 6497e00aa60de23d10f098740524095b64537dfc | |
parent | 16b347628af89740c5cd4f5c6c9fd632fea55360 (diff) | |
download | gsl-shell-53b62a19f80914481b3c4ddd29a34e58fb43b543.tar.gz |
-rw-r--r-- | lua-gsl.c | 2 | ||||
-rw-r--r-- | matrix_arith.c | 146 | ||||
-rw-r--r-- | matrix_arith.h | 2 | ||||
-rw-r--r-- | matrix_decls_source.c | 11 | ||||
-rw-r--r-- | matrix_headers_source.h | 7 | ||||
-rw-r--r-- | matrix_source.c | 40 |
@@ -28,6 +28,7 @@ #include "cnlinfit.h" #include "matrix.h" #include "cmatrix.h" +#include "matrix_arith.h" #include "linalg.h" #include "integ.h" #include "fft.h" @@ -76,6 +77,7 @@ luaopen_gsl (lua_State *L) solver_register (L); matrix_register (L); + matrix_arith_register (L); linalg_register (L); integ_register (L); ode_register (L); diff --git a/matrix_arith.c b/matrix_arith.c index 9b6ae233..13155d68 100644 --- a/matrix_arith.c +++ b/matrix_arith.c @@ -23,6 +23,9 @@ #include <assert.h> #include <string.h> #include <gsl/gsl_matrix.h> +#include <gsl/gsl_blas.h> +#include <gsl/gsl_permutation.h> +#include <gsl/gsl_linalg.h> #include "gs-types.h" #include "matrix.h" @@ -30,6 +33,27 @@ #include "matrix_arith.h" #include "lua-utils.h" +struct pmatrix { + int tp; + union { + gsl_matrix *real; + gsl_matrix_complex *cmpl; + } m; +}; + +static int matrix_inv (lua_State *L); +static int matrix_solve (lua_State *L); +static int matrix_mul (lua_State *L); + +static const struct luaL_Reg matrix_arith_functions[] = { + // {"dim", matrix_dim)}, + // {"copy", matrix_copy}, + {"mul", matrix_mul}, + {"solve", matrix_solve}, + {"inv", matrix_inv}, + {NULL, NULL} +}; + static const char * size_err_msg = "matrices should have the same size in %s"; static gsl_matrix_complex * @@ -53,6 +77,33 @@ push_matrix_complex_of_real (lua_State *L, const gsl_matrix *a) return r; } +static void +check_matrix_type (lua_State *L, int index, struct pmatrix *r) +{ + if (gs_is_userdata (L, index, GS_MATRIX)) + { + r->tp = GS_MATRIX; + r->m.real = lua_touserdata (L, index); + } + else if (gs_is_userdata (L, index, GS_CMATRIX)) + { + r->tp = GS_CMATRIX; + r->m.cmpl = lua_touserdata (L, index); + } + else + { + gs_type_error (L, index, "matrix"); + } +} + +static void +matrix_complex_promote (lua_State *L, int index, struct pmatrix *a) +{ + a->tp = GS_CMATRIX; + a->m.cmpl = push_matrix_complex_of_real (L, a->m.real); + lua_replace (L, index); +} + #define OPER_ADD #include "template_matrix_oper_on.h" #include "matrix_op_source.c" @@ -120,3 +171,98 @@ matrix_unm (lua_State *L) return 1; } + +int +matrix_mul (lua_State *L) +{ + int nargs = lua_gettop (L); + struct pmatrix a, b, r; + int k; + + for (k = nargs - 1; k >= 1; k--) + { + check_matrix_type (L, k, &a); + check_matrix_type (L, k+1, &b); + + r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX); + + if (a.tp != r.tp) + matrix_complex_promote (L, k, &a); + + if (b.tp != r.tp) + matrix_complex_promote (L, k+1, &b); + + if (r.tp == GS_MATRIX) + r.m.real = matrix_push (L, a.m.real->size1, b.m.real->size2); + else + r.m.cmpl = matrix_complex_push (L, a.m.cmpl->size1, b.m.cmpl->size2); + + if (r.tp == GS_MATRIX) + { + gsl_blas_dgemm (CblasNoTrans, CblasNoTrans, + 1.0, a.m.real, b.m.real, 1.0, r.m.real); + } + else + { + gsl_complex u = {{1.0, 0.0}}; + gsl_blas_zgemm (CblasNoTrans, CblasNoTrans, + u, a.m.cmpl, b.m.cmpl, u, r.m.cmpl); + } + + lua_insert (L, k); + lua_pop (L, 2); + } + + return 1; +} + +int +matrix_inv (lua_State *L) +{ + struct pmatrix a; + check_matrix_type (L, 1, &a); + switch (a.tp) + { + case GS_MATRIX: + return matrix_inverse_raw (L, a.m.real); + case GS_CMATRIX: + return matrix_complex_inverse_raw (L, a.m.cmpl); + default: + /* */; + } + return 0; +} + +int +matrix_solve (lua_State *L) +{ + struct pmatrix a, b, r; + check_matrix_type (L, 1, &a); + check_matrix_type (L, 2, &b); + + r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX); + + if (a.tp != r.tp) + matrix_complex_promote (L, 1, &a); + + if (b.tp != r.tp) + matrix_complex_promote (L, 2, &b); + + switch (r.tp) + { + case GS_MATRIX: + return matrix_solve_raw (L, a.m.real, b.m.real); + case GS_CMATRIX: + return matrix_complex_solve_raw (L, a.m.cmpl, b.m.cmpl); + default: + /* */; + } + + return 0; +} + +void +matrix_arith_register (lua_State *L) +{ + luaL_register (L, NULL, matrix_arith_functions); +} diff --git a/matrix_arith.h b/matrix_arith.h index a3f9beca..7116f7bb 100644 --- a/matrix_arith.h +++ b/matrix_arith.h @@ -10,4 +10,6 @@ extern int matrix_mul_elements (lua_State *L); extern int matrix_div_elements (lua_State *L); extern int matrix_unm (lua_State *L); +extern void matrix_arith_register (lua_State *L); + #endif diff --git a/matrix_decls_source.c b/matrix_decls_source.c index a3402bcf..6afe0a48 100644 --- a/matrix_decls_source.c +++ b/matrix_decls_source.c @@ -28,10 +28,6 @@ static int FUNCTION (matrix, dims) (lua_State *L); static int FUNCTION (matrix, copy) (lua_State *L); static int FUNCTION (matrix, slice) (lua_State *L); -static int FUNCTION (matrix, mul) (lua_State *L); -static int FUNCTION (matrix, inverse) (lua_State *L); -static int FUNCTION (matrix, solve) (lua_State *L); - static void FUNCTION (matrix, set_ref) (lua_State *L, int index); static const struct luaL_Reg FUNCTION (matrix, methods)[] = { @@ -41,18 +37,15 @@ static const struct luaL_Reg FUNCTION (matrix, methods)[] = { {"__div", matrix_div_elements}, {"__unm", matrix_unm}, {"__gc", FUNCTION (matrix, free)}, - {"get", FUNCTION (matrix, get)}, - {"set", FUNCTION (matrix, set)}, {"dims", FUNCTION (matrix, dims)}, {"copy", FUNCTION (matrix, copy)}, + {"get", FUNCTION (matrix, get)}, + {"set", FUNCTION (matrix, set)}, {"slice", FUNCTION (matrix, slice)}, {NULL, NULL} }; static const struct luaL_Reg FUNCTION (matrix, functions)[] = { {PREFIX "new", FUNCTION (matrix, new)}, - {PREFIX "mul", FUNCTION (matrix, mul)}, - {PREFIX "solve", FUNCTION (matrix, solve)}, - {PREFIX "inverse", FUNCTION (matrix, inverse)}, {NULL, NULL} }; diff --git a/matrix_headers_source.h b/matrix_headers_source.h index e63bdab7..0515d990 100644 --- a/matrix_headers_source.h +++ b/matrix_headers_source.h @@ -20,6 +20,13 @@ extern void FUNCTION (matrix, check_size) (lua_State *L, TYPE (gsl_matrix) *m, size_t n1, size_t n2); +extern int FUNCTION (matrix, inverse_raw)(lua_State *L, + const TYPE (gsl_matrix) *a); + +extern int FUNCTION (matrix, solve_raw) (lua_State *L, + const TYPE (gsl_matrix) *a, + const TYPE (gsl_matrix) *b); + /* matrix helper functions */ extern void diff --git a/matrix_source.c b/matrix_source.c index b55e1280..ea04763b 100644 --- a/matrix_source.c +++ b/matrix_source.c @@ -268,40 +268,8 @@ FUNCTION(matrix, new) (lua_State *L) } int -FUNCTION(matrix, mul) (lua_State *L) +FUNCTION(matrix, inverse_raw) (lua_State *L, const TYPE (gsl_matrix) *a) { - int k, nargs = lua_gettop (L); - TYPE (gsl_matrix) *r, *a, *b; - BASE one = ONE; - - r = FUNCTION (matrix, check) (L, 1); - - for (k = 2; k <= nargs; k++) - { - a = r; - b = FUNCTION (matrix, check) (L, k); - - if (a->size2 != b->size1) - luaL_error (L, "incompatible matrix dimensions in multiplication"); - - if (k == nargs) - r = FUNCTION (matrix, push) (L, a->size1, b->size2); - else - r = FUNCTION (gsl_matrix, calloc) (a->size1, b->size2); - - BLAS_FUNCTION(gemm) (CblasNoTrans, CblasNoTrans, one, a, b, one, r); - - if (k > 2) - FUNCTION (gsl_matrix, free) (a); - } - - return 1; -} - -int -FUNCTION(matrix, inverse) (lua_State *L) -{ - const TYPE (gsl_matrix) *a = FUNCTION (matrix, check) (L, 1); TYPE (gsl_matrix) *lu, *inverse; gsl_permutation *p; size_t n = a->size1; @@ -326,10 +294,10 @@ FUNCTION(matrix, inverse) (lua_State *L) return 1; } int -FUNCTION(matrix, solve) (lua_State *L) +FUNCTION(matrix, solve_raw) (lua_State *L, + const TYPE (gsl_matrix) *a, + const TYPE (gsl_matrix) *b) { - const TYPE (gsl_matrix) *a = FUNCTION (matrix, check) (L, 1); - const TYPE (gsl_matrix) *b = FUNCTION (matrix, check) (L, 2); TYPE (gsl_matrix) *x; CONST_VIEW (gsl_vector) b_view = CONST_FUNCTION (gsl_matrix, column) (b, 0); VIEW (gsl_vector) x_view; |