gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat
-rw-r--r--matrix.lua 18
-rw-r--r--matrix_arith.c 46
-rw-r--r--matrix_decls_source.c 8
-rw-r--r--matrix_source.c 6
4 files changed, 55 insertions, 23 deletions
diff --git a/matrix.lua b/matrix.lua
index 73dfd421..53a78224 100644
--- a/matrix.lua
+++ b/matrix.lua
@@ -82,7 +82,7 @@ end
local function matrix_to_string(m)
local eps = m:norm() * 1e-8
local fwidth = function(w, val)
- local ln = # tostring_eps(val, eps)
+ local ln = # gsl.tostring_eps(val, eps)
return (ln > w and ln or w)
end
local width = gsl.matrix_reduce(m, fwidth, 0)
@@ -91,7 +91,7 @@ local function matrix_to_string(m)
for i=1,r do
local ln = {}
for j=1,c do
- insert(ln, padstr(tostring_eps(m:get(i,j), eps), width))
+ insert(ln, padstr(gsl.tostring_eps(m:get(i,j), eps), width))
end
insert(lines, fmt('[ %s ]', cat(ln, ' ')))
end
@@ -148,18 +148,6 @@ local function matrix_rows(m)
return gsl.sequence(function(i) m:slice(i, 1, 1, c) end, r)
end
-function gsl.set(d, s)
- local r, c = gsl.dim(d)
- local rs, cs = gsl.dim(s)
- if rs ~= r or cs ~= c then error 'matrix dimensions does not match' end
- local dset, sget = d.set, s.get
- for i=1, r do
- for j=1, c do
- dset(d, i, j, sget(s, i, j))
- end
- end
-end
-
function gsl.null(m)
local r, c = gsl.dim(m)
local mset = m.set
@@ -199,7 +187,7 @@ end
local function hc_print(hc)
local eps = 1e-8 * hc_reduce(hc, function(p,z) return p + csqr(z) end, 0)
local f = function(p, z)
- insert(p, fmt('%6i: %s', #p, tostring_eps(z, eps)))
+ insert(p, fmt('%6i: %s', #p, gsl.tostring_eps(z, eps)))
return p
end
return cat(hc_reduce(hc, f, {}), '\n')
diff --git a/matrix_arith.c b/matrix_arith.c
index 722108dd..bca97ea3 100644
--- a/matrix_arith.c
+++ b/matrix_arith.c
@@ -45,7 +45,8 @@ static int matrix_inv (lua_State *L);
static int matrix_solve (lua_State *L);
static int matrix_dim (lua_State *L);
static int matrix_copy (lua_State *L);
-static int matrix_prod (lua_State *L);
+static int matrix_prod (lua_State *L);
+static int matrix_set (lua_State *L);
static const struct luaL_Reg matrix_arith_functions[] = {
{"dim", matrix_dim},
@@ -53,6 +54,7 @@ static const struct luaL_Reg matrix_arith_functions[] = {
{"solve", matrix_solve},
{"inv", matrix_inv},
{"prod", matrix_prod},
+ {"set", matrix_set},
{NULL, NULL}
};
@@ -403,6 +405,48 @@ matrix_prod (lua_State *L)
return 1;
}
+int
+matrix_set (lua_State *L)
+{
+ struct pmatrix a, b;
+ int rtp;
+
+ check_matrix_type (L, 1, &a);
+ check_matrix_type (L, 2, &b);
+
+ rtp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX);
+
+ if (a.tp != rtp)
+ matrix_complex_promote (L, 1, &a);
+
+ if (b.tp != rtp)
+ matrix_complex_promote (L, 2, &b);
+
+ switch (rtp)
+ {
+ case GS_MATRIX:
+ {
+ gsl_matrix *dst = a.m.real, *src = b.m.real;
+ if (dst->size1 != src->size1 || dst->size2 != src->size2)
+ luaL_error (L, "matrix dimensions does not match");
+ gsl_matrix_memcpy (dst, src);
+ break;
+ }
+ case GS_CMATRIX:
+ {
+ gsl_matrix_complex *dst = a.m.cmpl, *src = b.m.cmpl;
+ if (dst->size1 != src->size1 || dst->size2 != src->size2)
+ luaL_error (L, "matrix dimensions does not match");
+ gsl_matrix_complex_memcpy (dst, src);
+ break;
+ }
+ default:
+ /* */;
+ }
+
+ return 0;
+}
+
void
matrix_arith_register (lua_State *L)
{
diff --git a/matrix_decls_source.c b/matrix_decls_source.c
index cb6b0c3d..bc9f2a4b 100644
--- a/matrix_decls_source.c
+++ b/matrix_decls_source.c
@@ -20,11 +20,11 @@
#define NLINFIT_MAX_ITER 30
-static int FUNCTION (matrix, get) (lua_State *L);
static int FUNCTION (matrix, index) (lua_State *L);
static int FUNCTION (matrix, newindex) (lua_State *L);
static int FUNCTION (matrix, len) (lua_State *L);
-static int FUNCTION (matrix, set) (lua_State *L);
+static int FUNCTION (matrix, get_elem) (lua_State *L);
+static int FUNCTION (matrix, set_elem) (lua_State *L);
static int FUNCTION (matrix, free) (lua_State *L);
static int FUNCTION (matrix, new) (lua_State *L);
static int FUNCTION (matrix, slice) (lua_State *L);
@@ -43,8 +43,8 @@ static const struct luaL_Reg FUNCTION (matrix, meta_methods)[] = {
};
static const struct luaL_Reg FUNCTION (matrix, methods)[] = {
- {"get", FUNCTION (matrix, get)},
- {"set", FUNCTION (matrix, set)},
+ {"get", FUNCTION (matrix, get_elem)},
+ {"set", FUNCTION (matrix, set_elem)},
{"slice", FUNCTION (matrix, slice)},
{NULL, NULL}
};
diff --git a/matrix_source.c b/matrix_source.c
index 3f186c05..c8b6b581 100644
--- a/matrix_source.c
+++ b/matrix_source.c
@@ -138,7 +138,7 @@ FUNCTION(matrix, slice) (lua_State *L)
}
int
-FUNCTION(matrix, get) (lua_State *L)
+FUNCTION(matrix, get_elem) (lua_State *L)
{
const TYPE (gsl_matrix) *m = FUNCTION (matrix, check) (L, 1);
lua_Integer r = luaL_checkinteger (L, 2);
@@ -163,7 +163,7 @@ FUNCTION(matrix, get) (lua_State *L)
}
int
-FUNCTION(matrix, set) (lua_State *L)
+FUNCTION(matrix, set_elem) (lua_State *L)
{
TYPE (gsl_matrix) *m = FUNCTION (matrix, check) (L, 1);
lua_Integer r = luaL_checkinteger (L, 2);
@@ -302,7 +302,7 @@ FUNCTION(matrix, solve_raw) (lua_State *L,
if (b->size2 != 1)
gs_type_error (L, 1, "vector");
if (b->size1 != n)
- luaL_error (L, "dimensions of vector does not match with matrix");
+ luaL_error (L, "dimensions of vector does not match");
x = FUNCTION (matrix, push_raw) (L, n, 1);
x_view = FUNCTION (gsl_matrix, column) (x, 0);
generated by cgit v1.2.3 (git 2.25.1) at 2025年09月17日 08:31:54 +0000

AltStyle によって変換されたページ (->オリジナル) /