-rw-r--r-- | matrix.lua | 18 | ||||
-rw-r--r-- | matrix_arith.c | 46 | ||||
-rw-r--r-- | matrix_decls_source.c | 8 | ||||
-rw-r--r-- | matrix_source.c | 6 |
diff --git a/matrix.lua b/matrix.lua index 73dfd421..53a78224 100644 --- a/matrix.lua +++ b/matrix.lua @@ -82,7 +82,7 @@ end local function matrix_to_string(m) local eps = m:norm() * 1e-8 local fwidth = function(w, val) - local ln = # tostring_eps(val, eps) + local ln = # gsl.tostring_eps(val, eps) return (ln > w and ln or w) end local width = gsl.matrix_reduce(m, fwidth, 0) @@ -91,7 +91,7 @@ local function matrix_to_string(m) for i=1,r do local ln = {} for j=1,c do - insert(ln, padstr(tostring_eps(m:get(i,j), eps), width)) + insert(ln, padstr(gsl.tostring_eps(m:get(i,j), eps), width)) end insert(lines, fmt('[ %s ]', cat(ln, ' '))) end @@ -148,18 +148,6 @@ local function matrix_rows(m) return gsl.sequence(function(i) m:slice(i, 1, 1, c) end, r) end -function gsl.set(d, s) - local r, c = gsl.dim(d) - local rs, cs = gsl.dim(s) - if rs ~= r or cs ~= c then error 'matrix dimensions does not match' end - local dset, sget = d.set, s.get - for i=1, r do - for j=1, c do - dset(d, i, j, sget(s, i, j)) - end - end -end - function gsl.null(m) local r, c = gsl.dim(m) local mset = m.set @@ -199,7 +187,7 @@ end local function hc_print(hc) local eps = 1e-8 * hc_reduce(hc, function(p,z) return p + csqr(z) end, 0) local f = function(p, z) - insert(p, fmt('%6i: %s', #p, tostring_eps(z, eps))) + insert(p, fmt('%6i: %s', #p, gsl.tostring_eps(z, eps))) return p end return cat(hc_reduce(hc, f, {}), '\n') diff --git a/matrix_arith.c b/matrix_arith.c index 722108dd..bca97ea3 100644 --- a/matrix_arith.c +++ b/matrix_arith.c @@ -45,7 +45,8 @@ static int matrix_inv (lua_State *L); static int matrix_solve (lua_State *L); static int matrix_dim (lua_State *L); static int matrix_copy (lua_State *L); -static int matrix_prod (lua_State *L); +static int matrix_prod (lua_State *L); +static int matrix_set (lua_State *L); static const struct luaL_Reg matrix_arith_functions[] = { {"dim", matrix_dim}, @@ -53,6 +54,7 @@ static const struct luaL_Reg matrix_arith_functions[] = { {"solve", matrix_solve}, {"inv", matrix_inv}, {"prod", matrix_prod}, + {"set", matrix_set}, {NULL, NULL} }; @@ -403,6 +405,48 @@ matrix_prod (lua_State *L) return 1; } +int +matrix_set (lua_State *L) +{ + struct pmatrix a, b; + int rtp; + + check_matrix_type (L, 1, &a); + check_matrix_type (L, 2, &b); + + rtp = (a.tp == GS_MATRIX && b.tp == GS_MATRIX ? GS_MATRIX : GS_CMATRIX); + + if (a.tp != rtp) + matrix_complex_promote (L, 1, &a); + + if (b.tp != rtp) + matrix_complex_promote (L, 2, &b); + + switch (rtp) + { + case GS_MATRIX: + { + gsl_matrix *dst = a.m.real, *src = b.m.real; + if (dst->size1 != src->size1 || dst->size2 != src->size2) + luaL_error (L, "matrix dimensions does not match"); + gsl_matrix_memcpy (dst, src); + break; + } + case GS_CMATRIX: + { + gsl_matrix_complex *dst = a.m.cmpl, *src = b.m.cmpl; + if (dst->size1 != src->size1 || dst->size2 != src->size2) + luaL_error (L, "matrix dimensions does not match"); + gsl_matrix_complex_memcpy (dst, src); + break; + } + default: + /* */; + } + + return 0; +} + void matrix_arith_register (lua_State *L) { diff --git a/matrix_decls_source.c b/matrix_decls_source.c index cb6b0c3d..bc9f2a4b 100644 --- a/matrix_decls_source.c +++ b/matrix_decls_source.c @@ -20,11 +20,11 @@ #define NLINFIT_MAX_ITER 30 -static int FUNCTION (matrix, get) (lua_State *L); static int FUNCTION (matrix, index) (lua_State *L); static int FUNCTION (matrix, newindex) (lua_State *L); static int FUNCTION (matrix, len) (lua_State *L); -static int FUNCTION (matrix, set) (lua_State *L); +static int FUNCTION (matrix, get_elem) (lua_State *L); +static int FUNCTION (matrix, set_elem) (lua_State *L); static int FUNCTION (matrix, free) (lua_State *L); static int FUNCTION (matrix, new) (lua_State *L); static int FUNCTION (matrix, slice) (lua_State *L); @@ -43,8 +43,8 @@ static const struct luaL_Reg FUNCTION (matrix, meta_methods)[] = { }; static const struct luaL_Reg FUNCTION (matrix, methods)[] = { - {"get", FUNCTION (matrix, get)}, - {"set", FUNCTION (matrix, set)}, + {"get", FUNCTION (matrix, get_elem)}, + {"set", FUNCTION (matrix, set_elem)}, {"slice", FUNCTION (matrix, slice)}, {NULL, NULL} }; diff --git a/matrix_source.c b/matrix_source.c index 3f186c05..c8b6b581 100644 --- a/matrix_source.c +++ b/matrix_source.c @@ -138,7 +138,7 @@ FUNCTION(matrix, slice) (lua_State *L) } int -FUNCTION(matrix, get) (lua_State *L) +FUNCTION(matrix, get_elem) (lua_State *L) { const TYPE (gsl_matrix) *m = FUNCTION (matrix, check) (L, 1); lua_Integer r = luaL_checkinteger (L, 2); @@ -163,7 +163,7 @@ FUNCTION(matrix, get) (lua_State *L) } int -FUNCTION(matrix, set) (lua_State *L) +FUNCTION(matrix, set_elem) (lua_State *L) { TYPE (gsl_matrix) *m = FUNCTION (matrix, check) (L, 1); lua_Integer r = luaL_checkinteger (L, 2); @@ -302,7 +302,7 @@ FUNCTION(matrix, solve_raw) (lua_State *L, if (b->size2 != 1) gs_type_error (L, 1, "vector"); if (b->size1 != n) - luaL_error (L, "dimensions of vector does not match with matrix"); + luaL_error (L, "dimensions of vector does not match"); x = FUNCTION (matrix, push_raw) (L, n, 1); x_view = FUNCTION (gsl_matrix, column) (x, 0); |