-rw-r--r-- | matrix_arith.c | 146 |
diff --git a/matrix_arith.c b/matrix_arith.c index 9b6ae233..13155d68 100644 --- a/matrix_arith.c +++ b/matrix_arith.c @@ -23,6 +23,9 @@ #include <assert.h> #include <string.h> #include <gsl/gsl_matrix.h> +#include <gsl/gsl_blas.h> +#include <gsl/gsl_permutation.h> +#include <gsl/gsl_linalg.h> #include "gs-types.h" #include "matrix.h" @@ -30,6 +33,27 @@ #include "matrix_arith.h" #include "lua-utils.h" +struct pmatrix { + int tp; + union { + gsl_matrix *real; + gsl_matrix_complex *cmpl; + } m; +}; + +static int matrix_inv (lua_State *L); +static int matrix_solve (lua_State *L); +static int matrix_mul (lua_State *L); + +static const struct luaL_Reg matrix_arith_functions[] = { + // {"dim", matrix_dim)}, + // {"copy", matrix_copy}, + {"mul", matrix_mul}, + {"solve", matrix_solve}, + {"inv", matrix_inv}, + {NULL, NULL} +}; + static const char * size_err_msg = "matrices should have the same size in %s"; static gsl_matrix_complex * @@ -53,6 +77,33 @@ push_matrix_complex_of_real (lua_State *L, const gsl_matrix *a) return r; } +static void +check_matrix_type (lua_State *L, int index, struct pmatrix *r) +{ + if (gs_is_userdata (L, index, GS_MATRIX)) + { + r->tp = GS_MATRIX; + r->m.real = lua_touserdata (L, index); + } + else if (gs_is_userdata (L, index, GS_CMATRIX)) + { + r->tp = GS_CMATRIX; + r->m.cmpl = lua_touserdata (L, index); + } + else + { + gs_type_error (L, index, "matrix"); + } +} + +static void +matrix_complex_promote (lua_State *L, int index, struct pmatrix *a) +{ + a->tp = GS_CMATRIX; + a->m.cmpl = push_matrix_complex_of_real (L, a->m.real); + lua_replace (L, index); +} + #define OPER_ADD #include "template_matrix_oper_on.h" #include "matrix_op_source.c" @@ -120,3 +171,98 @@ matrix_unm (lua_State *L) return 1; } + +int +matrix_mul (lua_State *L) +{ + int nargs = lua_gettop (L); + struct pmatrix a, b, r; + int k; + + for (k = nargs - 1; k >= 1; k--) + { + check_matrix_type (L, k, &a); + check_matrix_type (L, k+1, &b); + + r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX); + + if (a.tp != r.tp) + matrix_complex_promote (L, k, &a); + + if (b.tp != r.tp) + matrix_complex_promote (L, k+1, &b); + + if (r.tp == GS_MATRIX) + r.m.real = matrix_push (L, a.m.real->size1, b.m.real->size2); + else + r.m.cmpl = matrix_complex_push (L, a.m.cmpl->size1, b.m.cmpl->size2); + + if (r.tp == GS_MATRIX) + { + gsl_blas_dgemm (CblasNoTrans, CblasNoTrans, + 1.0, a.m.real, b.m.real, 1.0, r.m.real); + } + else + { + gsl_complex u = {{1.0, 0.0}}; + gsl_blas_zgemm (CblasNoTrans, CblasNoTrans, + u, a.m.cmpl, b.m.cmpl, u, r.m.cmpl); + } + + lua_insert (L, k); + lua_pop (L, 2); + } + + return 1; +} + +int +matrix_inv (lua_State *L) +{ + struct pmatrix a; + check_matrix_type (L, 1, &a); + switch (a.tp) + { + case GS_MATRIX: + return matrix_inverse_raw (L, a.m.real); + case GS_CMATRIX: + return matrix_complex_inverse_raw (L, a.m.cmpl); + default: + /* */; + } + return 0; +} + +int +matrix_solve (lua_State *L) +{ + struct pmatrix a, b, r; + check_matrix_type (L, 1, &a); + check_matrix_type (L, 2, &b); + + r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX); + + if (a.tp != r.tp) + matrix_complex_promote (L, 1, &a); + + if (b.tp != r.tp) + matrix_complex_promote (L, 2, &b); + + switch (r.tp) + { + case GS_MATRIX: + return matrix_solve_raw (L, a.m.real, b.m.real); + case GS_CMATRIX: + return matrix_complex_solve_raw (L, a.m.cmpl, b.m.cmpl); + default: + /* */; + } + + return 0; +} + +void +matrix_arith_register (lua_State *L) +{ + luaL_register (L, NULL, matrix_arith_functions); +} |