Rationalize initialization of modules - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
path: root/matrix.lua
diff options
context:
space:
mode:
authorFrancesco Abbate <francesco.bbt@gmail.com>2011年10月26日 11:22:04 +0200
committerFrancesco Abbate <francesco.bbt@gmail.com>2011年10月26日 11:22:04 +0200
commit1295e713310a8c57f9b9d3a34e094f600d61fa75 (patch)
tree3ee0ece5c34c6813e1eb886a94e2fd66867eb956 /matrix.lua
parent505c0c9226842d4989e90fba8fe56cd3c42040d3 (diff)
downloadgsl-shell-1295e713310a8c57f9b9d3a34e094f600d61fa75.tar.gz
Rationalize initialization of modules
Diffstat (limited to 'matrix.lua')
-rw-r--r--matrix.lua 935
1 files changed, 935 insertions, 0 deletions
diff --git a/matrix.lua b/matrix.lua
new file mode 100644
index 00000000..04ad6973
--- /dev/null
+++ b/matrix.lua
@@ -0,0 +1,935 @@
+
+local ffi = require 'ffi'
+local gsl = require 'gsl'
+
+local sqrt, abs = math.sqrt, math.abs
+local format = string.format
+
+local check = require 'check'
+local is_integer, is_real = check.is_integer, check.is_real
+
+local lua_index_style = gslsh.lua_index_style
+
+local gsl_matrix = ffi.typeof('gsl_matrix')
+local gsl_matrix_complex = ffi.typeof('gsl_matrix_complex')
+local gsl_complex = ffi.typeof('complex')
+
+local gsl_check = require 'gsl-check'
+
+local function check_real(x)
+ if type(x) ~= 'number' then error('expected real number', 3) end
+ return x
+end
+
+local function get_typeid(a)
+ if is_real(a) then return true, true
+ elseif ffi.istype(gsl_complex, a) then return false, true
+ elseif ffi.istype(gsl_matrix, a) then return true, false
+ elseif ffi.istype(gsl_matrix_complex, a) then return false, false end
+end
+
+local function check_typeid(a)
+ local isr, iss = get_typeid(a)
+ if isr == nil then error('expected matrix of scalar', 2) end
+ return isr, iss
+end
+
+local function cartesian(x)
+ if is_real(x) then
+ return x, 0
+ else
+ return x[0], x[1]
+ end
+end
+
+local function block_alloc(n)
+ local b = ffi.cast('gsl_block *', ffi.C.malloc(ffi.sizeof('gsl_block')))
+ local data = ffi.C.malloc(n * ffi.sizeof('double'))
+ b.size, b.data, b.ref_count = n, data, 1
+ return b
+end
+
+local function block_calloc(n)
+ local b = ffi.cast('gsl_block_complex *', ffi.C.malloc(ffi.sizeof('gsl_block_complex')))
+ local data = ffi.C.malloc(2 * n * ffi.sizeof('double'))
+ b.size, b.data, b.ref_count = n, data, 1
+ return b
+end
+
+local function matrix_alloc(n1, n2)
+ local b = block_alloc(n1 * n2)
+ local m = gsl_matrix(n1, n2, n2, b.data, b, 1)
+ return m
+end
+
+local function matrix_calloc(n1, n2)
+ local b = block_calloc(n1 * n2)
+ local m = gsl_matrix_complex(n1, n2, n2, b.data, b, 1)
+ return m
+end
+
+local function matrix_new(n1, n2, f)
+ local m = matrix_alloc(n1, n2)
+ if f then
+ for i=0, n1-1 do
+ for j=0, n2-1 do
+ local x = check_real(lua_index_style and f(i+1, j+1) or f(i,j))
+ m.data[i*n2+j] = x
+ end
+ end
+ else
+ ffi.fill(m.data, n1 * n2 * ffi.sizeof('double'), 0)
+ end
+ return m
+end
+
+local function matrix_cnew(n1, n2, f)
+ local m = matrix_calloc(n1, n2)
+ if f then
+ for i=0, n1-1 do
+ for j=0, n2-1 do
+ local z = lua_index_style and f(i+1, j+1) or f(i,j)
+ local x, y = cartesian(z)
+ m.data[2*i*n2+2*j ] = x
+ m.data[2*i*n2+2*j+1] = y
+ end
+ end
+ else
+ ffi.fill(m.data, 2 * n1 * n2 * ffi.sizeof('double'), 0)
+ end
+ return m
+end
+
+local function matrix_free(m)
+ if m.owner then
+ local b = m.block
+ b.ref_count = b.ref_count - 1
+ if b.ref_count == 0 then
+ ffi.C.free(b.data)
+ ffi.C.free(b)
+ end
+ end
+end
+
+local function matrix_dim(m)
+ return m.size1, m.size2
+end
+
+local function matrix_len(m)
+ return m.size1
+end
+
+local function matrix_copy(a)
+ local n1, n2 = matrix_dim(a)
+ local b = matrix_alloc(n1, n2)
+ gsl.gsl_matrix_memcpy(b, a)
+ return b
+end
+
+local function matrix_complex_copy(a)
+ local n1, n2 = matrix_dim(a)
+ local b = matrix_calloc(n1, n2)
+ gsl.gsl_matrix_complex_memcpy(b, a)
+ return b
+end
+
+local function check_indices(m, i, j)
+ local r, c = matrix_dim(m)
+ if lua_index_style then i, j = i-1, j-1 end
+ if i < 0 or i >= r or j < 0 or j >= c then
+ error('matrix index out of bounds', 3)
+ end
+ return i, j
+end
+
+local function check_row_index(m, i)
+ if lua_index_style then i = i-1 end
+ if i < 0 or i >= matrix_len(m) then
+ error('matrix index out of bounds', 3)
+ end
+ return i
+end
+
+local function check_col_index(m, j)
+ if lua_index_style then j = j-1 end
+ if j < 0 or j >= m.size2 then
+ error('matrix index out of bounds', 3)
+ end
+ return j
+end
+
+local function matrix_get(m, i, j)
+ i, j = check_indices(m, i, j)
+ return gsl.gsl_matrix_get(m, i, j)
+end
+
+local function matrix_complex_get(m, i, j)
+ i, j = check_indices(m, i, j)
+ return gsl.gsl_matrix_complex_get(m, i, j)
+end
+
+local function matrix_set(m, i, j, v)
+ i, j = check_indices(m, i, j)
+ return gsl.gsl_matrix_set(m, i, j, v)
+end
+
+local function matrix_complex_set(m, i, j, v)
+ i, j = check_indices(m, i, j)
+ return gsl.gsl_matrix_complex_set(m, i, j, v)
+end
+
+local function complex_conj(z)
+ local x, y = cartesian(z)
+ return gsl_complex(x, -y)
+end
+
+local function complex_real(z)
+ local x = cartesian(z)
+ return x
+end
+
+local function complex_imag(z)
+ local x, y = cartesian(z)
+ return y
+end
+
+local function complex_abs(z)
+ local x, y = cartesian(z)
+ return sqrt(x*x + y*y)
+end
+
+local function itostr(im, eps, fmt, signed)
+ local absim = abs(im)
+ local sign = im + eps < 0 and '-' or (signed and '+' or '')
+ if absim < eps then return (signed and '' or '0') else
+ return sign .. (abs(absim-1) < eps and 'i' or format(fmt..'i', absim))
+ end
+end
+
+local function recttostr(x, y, eps)
+ local x_sub, y_sub = abs(x) < eps, abs(y) < eps
+
+ local fmt = '%.8g'
+ if x_sub and y_sub then
+ fmt, x_sub, y_sub = '%.0f', x==0, y==0
+ end
+
+ if not x_sub then
+ local sign = x+eps < 0 and '-' or ''
+ local ax = abs(x)
+ if y_sub then
+ return format('%s'..fmt, sign, ax)
+ else
+ return format('%s'..fmt..'%s', sign, ax, itostr(y, eps, fmt, true))
+ end
+ else
+ return (y_sub and '0' or itostr(y, eps, fmt, false))
+ end
+end
+
+local function concat_pad(t, pad)
+ local sep = ' '
+ local row
+ for i, s in ipairs(t) do
+ local x = string.rep(' ', pad - #s) .. s
+ row = row and (row .. sep .. x) or x
+ end
+ return row
+end
+
+local function matrix_tostring_gen(sel)
+ return function(m)
+ local n1, n2 = matrix_dim(m)
+ local sq = 0
+ for i=0, n1-1 do
+ for j=0, n2-1 do
+ local x, y = sel(m, i, j)
+ sq = sq + x*x + y*y
+ end
+ end
+ local eps = sqrt(sq) * 1e-8
+ eps = eps > 0 and eps or 1
+
+ lsrow = {}
+ local lmax = 0
+ for i=0, n1-1 do
+ local row = {}
+ for j=0, n2-1 do
+ local x, y = sel(m, i, j)
+ local s = recttostr(x, y, eps)
+ if #s > lmax then lmax = #s end
+ row[j+1] = s
+ end
+ lsrow[i+1] = row
+ end
+
+ local ss = {}
+ for i=0, n1-1 do
+ ss[i+1] = '[ ' .. concat_pad(lsrow[i+1], lmax) .. ' ]'
+ end
+
+ return table.concat(ss, '\n')
+ end
+end
+
+local function matrix_col(m, j)
+ j = check_col_index (m, j)
+ local mb = m.block
+ local r = gsl_matrix(m.size1, 1, m.tda, m.data + j, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_row(m, i)
+ i = check_row_index (m, i)
+ local mb = m.block
+ local r = gsl_matrix(1, m.size2, 1, m.data + i*m.tda, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_row_as_column(m, i)
+ i = check_row_index (m, i)
+ local mb = m.block
+ local r = gsl_matrix(m.size2, 1, 1, m.data + i*m.tda, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_slice(m, i, j, ni, nj)
+ check_indices (m, i+ni-1, j+nj-1)
+ i, j = check_indices (m, i, j)
+ local mb = m.block
+ local r = gsl_matrix(ni, nj, m.tda, m.data + i*m.tda + j, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_complex_col(m, j)
+ j = check_col_index (m, j)
+ local mb = m.block
+ local r = gsl_matrix_complex(m.size1, 1, m.tda, m.data + 2*j, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_complex_row(m, i)
+ i = check_row_index (m, i)
+ local mb = m.block
+ local r = gsl_matrix_complex(1, m.size2, 1, m.data + 2*i*m.tda, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_complex_row_as_column(m, i)
+ i = check_row_index (m, i)
+ local mb = m.block
+ local r = gsl_matrix_complex(m.size2, 1, 1, m.data + 2*i*m.tda, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_complex_slice(m, i, j, ni, nj)
+ check_indices (m, i+ni-1, j+nj-1)
+ i, j = check_indices (m, i, j)
+ local mb = m.block
+ local r = gsl_matrix_complex(ni, nj, m.tda, m.data + 2*i*m.tda + 2*j, mb, 1)
+ mb.ref_count = mb.ref_count + 1
+ return r
+end
+
+local function matrix_vect_def(t)
+ local n = #t
+ local isr = true
+ for i=1,n do
+ if not is_real(t[i]) then
+ isr = false
+ break
+ end
+ end
+
+ if isr then
+ local m = matrix_alloc(n, 1)
+ for i=0, n-1 do
+ m.data[i] = t[i+1]
+ end
+ return m
+ else
+ local m = matrix_calloc(n, 1)
+ for i=0, n-1 do
+ local x, y = cartesian(t[i+1])
+ m.data[2*i ] = x
+ m.data[2*i+1] = y
+ end
+ return m
+ end
+end
+
+local function mat_op_gen(n1, n2, opa, a, opb, b, oper)
+ local c = matrix_alloc(n1, n2)
+ for i = 0, n1-1 do
+ for j = 0, n2-1 do
+ local ar = opa(a,i,j)
+ local br = opb(b,i,j)
+ c.data[i*n2+j] = oper(ar, br)
+ end
+ end
+ return c
+end
+
+local function mat_comp_op_gen(n1, n2, opa, a, opb, b, oper)
+ local c = matrix_calloc(n1, n2)
+ for i = 0, n1-1 do
+ for j = 0, n2-1 do
+ local ar, ai = opa(a,i,j)
+ local br, bi = opb(b,i,j)
+ local zr, zi = oper(ar, br, ai, bi)
+ c.data[2*i*n2+2*j ] = zr
+ c.data[2*i*n2+2*j+1] = zi
+ end
+ end
+ return c
+end
+
+local function real_get(x) return x, 0 end
+local function complex_get(z) return z[0], z[1] end
+local function mat_real_get(m,i,j) return m.data[i*m.tda+j], 0 end
+
+local function mat_complex_get(m,i,j)
+ local idx = 2*i*m.tda+2*j
+ return m.data[idx], m.data[idx+1]
+end
+
+local function selector(r, s)
+ if s then
+ return r and real_get or complex_get
+ else
+ return r and mat_real_get or mat_complex_get
+ end
+end
+
+local function mat_complex_of_real(m)
+ local n1, n2 = matrix_dim(m)
+ local mc = matrix_calloc(n1, n2)
+ for i=0, n1-1 do
+ for j=0, n2-1 do
+ mc.data[2*i*n2+2*j ] = m.data[i*n2+j]
+ mc.data[2*i*n2+2*j+1] = 0
+ end
+ end
+ return mc
+end
+
+local function opadd(ar, br, ai, bi)
+ if ai then return ar+br, ai+bi else return ar+br end
+end
+
+local function opsub(ar, br, ai, bi)
+ if ai then return ar-br, ai-bi else return ar-br end
+end
+
+local function opmul(ar, br, ai, bi)
+ if ai then return ar*br-ai*bi, ar*bi+ai*br else return ar*br end
+end
+
+local function opdiv(ar, br, ai, bi)
+ if ai then
+ local d = br^2 + bi^2
+ return (ar*br + ai*bi)/d, (-ar*bi + ai*br)/d
+ else
+ return ar/br
+ end
+end
+
+local function vector_op(scalar_op, element_wise, no_inverse)
+ return function(a, b)
+ local ra, sa = get_typeid(a)
+ local rb, sb = get_typeid(b)
+ if not sb and no_inverse then
+ error 'invalid operation on matrix'
+ end
+ if sa and sb then
+ local ar, ai = cartesian(a)
+ local br, bi = cartesian(b)
+ local zr, zi = scalar_op(ar, br, ai, bi)
+ return gsl_complex(zr, zi)
+ elseif element_wise or sa or sb then
+ local sela, selb = selector(ra, sa), selector(rb, sb)
+ local n1 = (sa and b.size1 or a.size1)
+ local n2 = (sa and b.size2 or a.size2)
+ if ra and rb then
+ return mat_op_gen(n1, n2, sela, a, selb, b, scalar_op)
+ else
+ return mat_comp_op_gen(n1, n2, sela, a, selb, b, scalar_op)
+ end
+ else
+ if ra and rb then
+ local n1, n2 = a.size1, b.size2
+ local c = matrix_new(n1, n2)
+ local NT = gsl.CblasNoTrans
+ gsl_check(gsl.gsl_blas_dgemm(NT, NT, 1, a, b, 1, c))
+ return c
+ else
+ if ra then a = mat_complex_of_real(a) end
+ if rb then b = mat_complex_of_real(b) end
+ local n1, n2 = a.size1, b.size2
+ local c = matrix_cnew(n1, n2)
+ local NT = gsl.CblasNoTrans
+ gsl_check(gsl.gsl_blas_zgemm(NT, NT, 1, a, b, 1, c))
+ return c
+ end
+ end
+ end
+end
+
+local function matrix_norm(m)
+ local r, c = matrix_dim(m)
+ local tda = m.tda
+ local ssq, idx = 0, 0
+ for i = 0, r-1 do
+ for j = 0, c-1 do
+ local x = m.data[idx + j]
+ ssq = ssq + x*x
+ end
+ idx = idx + tda
+ end
+ return sqrt(ssq)
+end
+
+local function matrix_complex_norm(m)
+ local r, c = matrix_dim(m)
+ local tda = m.tda
+ local ssq, idx = 0, 0
+ for i = 0, r-1 do
+ for j = 0, c-1 do
+ local x, y = m.data[idx+2*j], m.data[idx+2*j+1]
+ ssq = ssq + x*x + y*y
+ end
+ idx = idx + 2*tda
+ end
+ return sqrt(ssq)
+end
+
+complex = {
+ new = gsl_complex,
+ conj = complex_conj,
+ real = complex_real,
+ imag = complex_imag,
+ abs = complex_abs,
+ rect = cartesian,
+}
+
+local generic_add = vector_op(opadd, true)
+local generic_sub = vector_op(opsub, true)
+local generic_mul = vector_op(opmul, false)
+local generic_div = vector_op(opdiv, true, true)
+
+local complex_mt = {
+
+ __add = generic_add,
+ __sub = generic_sub,
+ __mul = generic_mul,
+ __div = generic_div,
+
+ __pow = function(z,n)
+ if is_real(n) then
+ return gsl.gsl_complex_pow_real (z, n)
+ else
+ if is_real(z) then z = gsl_complex(z,0) end
+ return gsl.gsl_complex_pow (z, n)
+ end
+ end,
+}
+
+ffi.metatype(gsl_complex, complex_mt)
+
+local function matrix_new_unit(n)
+ local m = matrix_alloc(n, n)
+ for i=0, n-1 do
+ for j=0, n-1 do
+ m.data[i*n+j] = (i == j and 1 or 0)
+ end
+ end
+ return m
+end
+
+local function matrix_fset(m, f)
+ local n1, n2 = matrix_dim(m)
+ for i=1, n1 do
+ for j=1, n2 do
+ m:set(i, j, f(i,j))
+ end
+ end
+end
+
+local function matrix_set_equal(a, b)
+ local n1a, n2a = matrix_dim(a)
+ local n1b, n2b = matrix_dim(b)
+
+ if n1a ~= n1b or n2a ~= n2b then
+ error('matrix dimensions does not match', 2)
+ end
+
+ for i=1, n1b do
+ for j=1, n2b do
+ a:set(i, j, b:get(i,j))
+ end
+ end
+end
+
+local function matrix_new_transpose(a)
+ local n1, n2 = matrix_dim(a)
+ local b = a.alloc(n2, n1)
+ for i=1, n2 do
+ for j=1, n1 do
+ b:set(i, j, a:get(j,i))
+ end
+ end
+ return b
+end
+
+local function matrix_new_hc(a)
+ local n1, n2 = matrix_dim(a)
+ local b = a.alloc(n2, n1)
+ for i=1, n2 do
+ for j=1, n1 do
+ b:set(i, j, complex.conj(a:get(j,i)))
+ end
+ end
+ return b
+end
+
+local function matrix_new_copy(m)
+ return m:copy()
+end
+
+matrix = {
+ new = matrix_new,
+ cnew = matrix_cnew,
+ alloc = matrix_alloc,
+ calloc = matrix_calloc,
+ copy = matrix_new_copy,
+ unit = matrix_new_unit,
+ dim = matrix_dim,
+ vec = matrix_vect_def,
+ set = matrix_set_equal,
+ fset = matrix_fset,
+ block = block_alloc,
+
+ transpose = matrix_new_transpose,
+ hc = matrix_new_hc,
+}
+
+local matrix_methods = {
+ alloc = matrix_alloc,
+ col = matrix_col,
+ row = matrix_row,
+ get = matrix_get,
+ set = matrix_set,
+ copy = matrix_copy,
+ norm = matrix_norm,
+ slice = matrix_slice,
+}
+
+local function matrix_index(m, i)
+ if is_integer(i) then
+ if m.size2 == 1 then
+ i = check_row_index (m, i)
+ return m.data[i * m.tda]
+ else
+ return matrix_row_as_column(m, i)
+ end
+ end
+ return matrix_methods[i]
+end
+
+local function matrix_newindex(m, k, v)
+ if is_integer(k) then
+ local nr, nc = matrix_dim(m)
+ local isr, iss = check_typeid(v)
+ k = check_row_index (m, k)
+ if not isr then error('cannot assign element to a complex value') end
+ if nc == 1 then
+ if not iss then error('invalid assignment: expecting a scalar') end
+ m.data[k*m.tda] = v
+ else
+ if iss then error('invalid assignment: expecting a row matrix') end
+ if v.size1 ~= nc or v.size2 ~= 1 then
+ error('incompatible matrix dimensions in assignment')
+ end
+ for j = 0, nc-1 do
+ m.data[k*m.tda+j] = v.data[v.tda*j]
+ end
+ end
+ else
+ error 'cannot set a matrix field'
+ end
+end
+
+local matrix_mt = {
+ __gc = matrix_free,
+
+ __add = generic_add,
+ __sub = generic_sub,
+ __mul = generic_mul,
+ __div = generic_div,
+
+ __len = matrix_len,
+
+ __index = matrix_index,
+ __newindex = matrix_newindex,
+
+ __tostring = matrix_tostring_gen(mat_real_get),
+}
+
+ffi.metatype(gsl_matrix, matrix_mt)
+
+local matrix_complex_methods = {
+ alloc = matrix_calloc,
+ col = matrix_complex_col,
+ row = matrix_complex_row,
+ get = matrix_complex_get,
+ set = matrix_complex_set,
+ copy = matrix_complex_copy,
+ norm = matrix_complex_norm,
+ slice = matrix_complex_slice,
+}
+
+local function matrix_complex_index(m, i)
+ if is_integer(i) then
+ if m.size2 == 1 then
+ i = check_row_index (m, i)
+ return gsl_complex(m.data[2*i*m.tda], m.data[2*i*m.tda+1])
+ else
+ return matrix_complex_row_as_column(m, i)
+ end
+ end
+ return matrix_complex_methods[i]
+end
+
+local function matrix_complex_newindex(m, k, v)
+ if is_integer(k) then
+ local nr, nc = matrix_dim(m)
+ local isr, iss = check_typeid(v)
+ k = check_row_index (m, k)
+ if nc == 1 then
+ if not iss then error('invalid assignment: expecting a scalar') end
+ local vx, vy = cartesian(v)
+ m.data[2*k*m.tda ] = vx
+ m.data[2*k*m.tda+1] = vy
+ else
+ if iss then error('invalid assignment: expecting a row matrix') end
+ if v.size1 ~= nc or v.size2 ~= 1 then
+ error('incompatible matrix dimensions in assignment')
+ end
+ local sel = selector(isr, iss)
+ for j = 0, nc-1 do
+ local vx, vy = sel(v, j, 0)
+ m.data[2*k*m.tda+2*j ] = vx
+ m.data[2*k*m.tda+2*j+1] = vy
+ end
+ end
+ else
+ error 'cannot set a matrix field'
+ end
+end
+
+local matrix_complex_mt = {
+ __gc = matrix_free,
+
+ __add = generic_add,
+ __sub = generic_sub,
+ __mul = generic_mul,
+ __div = generic_div,
+
+ __len = matrix_len,
+
+ __index = matrix_complex_index,
+ __newindex = matrix_complex_newindex,
+
+ __tostring = matrix_tostring_gen(mat_complex_get),
+}
+
+ffi.metatype(gsl_matrix_complex, matrix_complex_mt)
+
+local function c_function_lookup(name)
+ return gsl['gsl_complex_' .. name]
+end
+
+local function c_invtrig_lookup(name)
+ return gsl['gsl_complex_arc' .. name]
+end
+
+local function csqrt(x)
+ return (is_real(x) and x >= 0) and sqrt(x) or gsl.gsl_complex_sqrt(x)
+end
+
+local gsl_function_list = {
+ 'exp', 'log', 'log10',
+ 'sin', 'cos', 'sec', 'csc', 'tan', 'cot',
+ 'sinh', 'cosh', 'sech', 'csch', 'tanh', 'coth',
+}
+
+local gsl_inverse_trig_list = {
+ 'sin', 'cos', 'sec', 'csc', 'tan', 'cot',
+ 'sinh', 'cosh', 'sech', 'csch', 'tanh', 'coth'
+}
+
+for _, name in ipairs(gsl_function_list) do
+ complex[name] = c_function_lookup(name)
+end
+
+for _, name in ipairs(gsl_inverse_trig_list) do
+ complex['a' .. name] = c_invtrig_lookup(name)
+end
+
+complex.sqrt = csqrt
+
+local function matrix_def(t)
+ local r, c = #t, #t[1]
+ local m = matrix_alloc(r, c)
+ for i= 0, r-1 do
+ local row = t[i+1]
+ for j = 0, c-1 do
+ local x = row[j+1]
+ if not is_real(x) then error('expected real number') end
+ m.data[i*c+j] = x
+ end
+ end
+ return m
+end
+
+local function matrix_cdef(t)
+ local r, c = #t, #t[1]
+ local m = matrix_calloc(r, c)
+ for i= 0, r-1 do
+ local row = t[i+1]
+ for j = 0, c-1 do
+ local x, y = cartesian(row[j+1])
+ m.data[2*i*c+2*j ] = x
+ m.data[2*i*c+2*j+1] = y
+ end
+ end
+ return m
+end
+
+local signum = ffi.new('int[1]')
+
+local function matrix_inv(m)
+ local n = m.size1
+ local lu = matrix_copy(m)
+ local p = ffi.gc(gsl.gsl_permutation_alloc(n), gsl.gsl_permutation_free)
+ gsl_check(gsl.gsl_linalg_LU_decomp(lu, p, signum))
+ local mi = matrix_alloc(n, n)
+ gsl_check(gsl.gsl_linalg_LU_invert(lu, p, mi))
+ return mi
+end
+
+local function matrix_solve(m, b)
+ local n = m.size1
+ local lu = matrix_copy(m)
+ local p = ffi.gc(gsl.gsl_permutation_alloc(n), gsl.gsl_permutation_free)
+ gsl_check(gsl.gsl_linalg_LU_decomp(lu, p, signum))
+ local x = matrix_alloc(n, 1)
+ local xv = gsl.gsl_matrix_column(x, 0)
+ local bv = gsl.gsl_matrix_column(b, 0)
+ gsl_check(gsl.gsl_linalg_LU_solve(lu, p, bv, xv))
+ return x
+end
+
+local function matrix_complex_inv(m)
+ local n = m.size1
+ local lu = matrix_complex_copy(m)
+ local p = ffi.gc(gsl.gsl_permutation_alloc(n), gsl.gsl_permutation_free)
+ gsl_check(gsl.gsl_linalg_complex_LU_decomp(lu, p, signum))
+ local mi = matrix_calloc(n, n)
+ gsl_check(gsl.gsl_linalg_complex_LU_invert(lu, p, mi))
+ return mi
+end
+
+local function matrix_complex_solve(m, b)
+ local n = m.size1
+ local lu = matrix_complex_copy(m)
+ local p = ffi.gc(gsl.gsl_permutation_alloc(n), gsl.gsl_permutation_free)
+ gsl_check(gsl.gsl_linalg_complex_LU_decomp(lu, p, signum))
+ local x = matrix_calloc(n, 1)
+ local xv = gsl.gsl_matrix_complex_column(x, 0)
+ local bv = gsl.gsl_matrix_complex_column(b, 0)
+ gsl_check(gsl.gsl_linalg_complex_LU_solve(lu, p, bv, xv))
+ return x
+end
+
+function matrix.inv(m)
+ if ffi.istype(gsl_matrix, m) then
+ return matrix_inv(m)
+ else
+ return matrix_complex_inv(m)
+ end
+end
+
+function matrix.solve(m, b)
+ local mr = ffi.istype(gsl_matrix, m)
+ local br = ffi.istype(gsl_matrix, b)
+ if mr and br then
+ return matrix_solve(m, b)
+ else
+ if mr then m = mat_complex_of_real(m) end
+ if br then b = mat_complex_of_real(b) end
+ return matrix_complex_solve(m, b)
+ end
+end
+
+local function matrix_sv_decomp(a, v, s, w)
+ local sv = gsl.gsl_matrix_column(s, 0)
+ local w
+ if w then
+ wv = gsl.gsl_matrix_column(w, 0)
+ else
+ local m, n = matrix_dim(a)
+ wv = ffi.gc(gsl.gsl_vector_alloc(n), gsl.gsl_vector_free)
+ end
+ gsl_check(gsl.gsl_linalg_SV_decomp (a, v, sv, wv))
+end
+
+function matrix.svd(a)
+ local m, n = matrix_dim(a)
+ local u = matrix_copy(a)
+ local v = matrix_alloc(n, n)
+ local s = matrix_new(n, n)
+ local sv = gsl.gsl_matrix_diagonal(s)
+ local wv = ffi.gc(gsl.gsl_vector_alloc(n), gsl.gsl_vector_free)
+ gsl_check(gsl.gsl_linalg_SV_decomp (u, v, sv, wv))
+ return u, s, v
+end
+
+matrix.sv_decomp = matrix_sv_decomp
+
+matrix.diag = function(d)
+ local n = #d
+ local m = d.alloc(n, n)
+ local mset, dget = m.set, d.get
+ for i=1, n do
+ for j= 1, n do
+ local x = (i ~= j and 0 or dget(d, i, 1))
+ mset(m, i, j, x)
+ end
+ end
+ return m
+ end
+
+matrix.tr = function(a)
+ local m, n = matrix_dim(a)
+ local b = a.alloc(n, m)
+ local bset, aget = b.set, a.get
+ for i=1, n do
+ for j= 1, m do
+ bset(b, i, j, aget(a, j, i))
+ end
+ end
+ return b
+ end
+
+matrix.def = matrix_def
+matrix.cdef = matrix_cdef
+
+return matrix
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月16日 09:47:36 +0000

AltStyle によって変換されたページ (->オリジナル) /