gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
path: root/matrix_arith.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix_arith.c')
-rw-r--r--matrix_arith.c 146
1 files changed, 146 insertions, 0 deletions
diff --git a/matrix_arith.c b/matrix_arith.c
index 9b6ae233..13155d68 100644
--- a/matrix_arith.c
+++ b/matrix_arith.c
@@ -23,6 +23,9 @@
#include <assert.h>
#include <string.h>
#include <gsl/gsl_matrix.h>
+#include <gsl/gsl_blas.h>
+#include <gsl/gsl_permutation.h>
+#include <gsl/gsl_linalg.h>
#include "gs-types.h"
#include "matrix.h"
@@ -30,6 +33,27 @@
#include "matrix_arith.h"
#include "lua-utils.h"
+struct pmatrix {
+ int tp;
+ union {
+ gsl_matrix *real;
+ gsl_matrix_complex *cmpl;
+ } m;
+};
+
+static int matrix_inv (lua_State *L);
+static int matrix_solve (lua_State *L);
+static int matrix_mul (lua_State *L);
+
+static const struct luaL_Reg matrix_arith_functions[] = {
+ // {"dim", matrix_dim)},
+ // {"copy", matrix_copy},
+ {"mul", matrix_mul},
+ {"solve", matrix_solve},
+ {"inv", matrix_inv},
+ {NULL, NULL}
+};
+
static const char * size_err_msg = "matrices should have the same size in %s";
static gsl_matrix_complex *
@@ -53,6 +77,33 @@ push_matrix_complex_of_real (lua_State *L, const gsl_matrix *a)
return r;
}
+static void
+check_matrix_type (lua_State *L, int index, struct pmatrix *r)
+{
+ if (gs_is_userdata (L, index, GS_MATRIX))
+ {
+ r->tp = GS_MATRIX;
+ r->m.real = lua_touserdata (L, index);
+ }
+ else if (gs_is_userdata (L, index, GS_CMATRIX))
+ {
+ r->tp = GS_CMATRIX;
+ r->m.cmpl = lua_touserdata (L, index);
+ }
+ else
+ {
+ gs_type_error (L, index, "matrix");
+ }
+}
+
+static void
+matrix_complex_promote (lua_State *L, int index, struct pmatrix *a)
+{
+ a->tp = GS_CMATRIX;
+ a->m.cmpl = push_matrix_complex_of_real (L, a->m.real);
+ lua_replace (L, index);
+}
+
#define OPER_ADD
#include "template_matrix_oper_on.h"
#include "matrix_op_source.c"
@@ -120,3 +171,98 @@ matrix_unm (lua_State *L)
return 1;
}
+
+int
+matrix_mul (lua_State *L)
+{
+ int nargs = lua_gettop (L);
+ struct pmatrix a, b, r;
+ int k;
+
+ for (k = nargs - 1; k >= 1; k--)
+ {
+ check_matrix_type (L, k, &a);
+ check_matrix_type (L, k+1, &b);
+
+ r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX);
+
+ if (a.tp != r.tp)
+ matrix_complex_promote (L, k, &a);
+
+ if (b.tp != r.tp)
+ matrix_complex_promote (L, k+1, &b);
+
+ if (r.tp == GS_MATRIX)
+ r.m.real = matrix_push (L, a.m.real->size1, b.m.real->size2);
+ else
+ r.m.cmpl = matrix_complex_push (L, a.m.cmpl->size1, b.m.cmpl->size2);
+
+ if (r.tp == GS_MATRIX)
+ {
+ gsl_blas_dgemm (CblasNoTrans, CblasNoTrans,
+ 1.0, a.m.real, b.m.real, 1.0, r.m.real);
+ }
+ else
+ {
+ gsl_complex u = {{1.0, 0.0}};
+ gsl_blas_zgemm (CblasNoTrans, CblasNoTrans,
+ u, a.m.cmpl, b.m.cmpl, u, r.m.cmpl);
+ }
+
+ lua_insert (L, k);
+ lua_pop (L, 2);
+ }
+
+ return 1;
+}
+
+int
+matrix_inv (lua_State *L)
+{
+ struct pmatrix a;
+ check_matrix_type (L, 1, &a);
+ switch (a.tp)
+ {
+ case GS_MATRIX:
+ return matrix_inverse_raw (L, a.m.real);
+ case GS_CMATRIX:
+ return matrix_complex_inverse_raw (L, a.m.cmpl);
+ default:
+ /* */;
+ }
+ return 0;
+}
+
+int
+matrix_solve (lua_State *L)
+{
+ struct pmatrix a, b, r;
+ check_matrix_type (L, 1, &a);
+ check_matrix_type (L, 2, &b);
+
+ r.tp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX);
+
+ if (a.tp != r.tp)
+ matrix_complex_promote (L, 1, &a);
+
+ if (b.tp != r.tp)
+ matrix_complex_promote (L, 2, &b);
+
+ switch (r.tp)
+ {
+ case GS_MATRIX:
+ return matrix_solve_raw (L, a.m.real, b.m.real);
+ case GS_CMATRIX:
+ return matrix_complex_solve_raw (L, a.m.cmpl, b.m.cmpl);
+ default:
+ /* */;
+ }
+
+ return 0;
+}
+
+void
+matrix_arith_register (lua_State *L)
+{
+ luaL_register (L, NULL, matrix_arith_functions);
+}
generated by cgit v1.2.3 (git 2.46.0) at 2025年10月01日 04:58:35 +0000

AltStyle によって変換されたページ (->オリジナル) /