-rw-r--r-- | fft-init.lua | 30 | ||||
-rw-r--r-- | iter.lua | 22 | ||||
-rw-r--r-- | matrix.lua | 11 |
diff --git a/fft-init.lua b/fft-init.lua index cd298417..d7bad58b 100644 --- a/fft-init.lua +++ b/fft-init.lua @@ -228,11 +228,22 @@ local function hc_length(ft) return tonumber(ft.size) end +local function halfcomplex_to_matrix(hc) + return matrix.cnew(tonumber(hc.size), 1, function(i) return hc[i-1] end) +end + +local function hc_tostring(hc) + local m = halfcomplex_to_matrix(hc) + return m:show() +end + local function hc_radix2_index(ft, k) if is_integer(k) then local idx = halfcomplex_radix2_index local size, stride = tonumber(ft.size), tonumber(ft.stride) return halfcomplex_get(idx, ft.data, size, stride, k) + elseif k == 'show' then + return hc_tostring end end @@ -249,6 +260,8 @@ local function hc_index(ft, k) local idx = halfcomplex_index local size, stride = tonumber(ft.size), tonumber(ft.stride) return halfcomplex_get(idx, ft.data, size, stride, k) + elseif k == 'show' then + return hc_tostring end end @@ -260,14 +273,6 @@ local function hc_newindex(ft, k, z) end end -local function halfcomplex_to_matrix(hc) - return matrix.cnew(tonumber(hc.size), 1, function(i) return hc[i-1] end) -end - -local function hc_tostring(hc) - return tostring(halfcomplex_to_matrix(hc)) -end - local function hc_free(hc) local b = hc.block b.ref_count = b.ref_count - 1 @@ -282,7 +287,7 @@ ffi.metatype(fft_hc, { __index = hc_index, __newindex = hc_newindex, __len = hc_length, - __tostring = hc_tostring, +-- __tostring = hc_tostring, } ) @@ -291,6 +296,11 @@ ffi.metatype(fft_radix2_hc, { __index = hc_radix2_index, __newindex = hc_radix2_newindex, __len = hc_length, - __tostring = hc_tostring, +-- __tostring = hc_tostring, } ) + +local register_ffi_type = debug.getregistry().__gsl_reg_ffi_type + +register_ffi_type(fft_radix2_hc, "radix2 half-complex vector") +register_ffi_type(fft_hc, "half-complex vector") @@ -22,11 +22,25 @@ local cat = table.concat local fmt = string.format do + local ffi = require('ffi') local reg = debug.getregistry() - gsl_type = function(t) - local s = reg.__gsl_type(t) - return (s == "cdata" and reg.__gsl_ffi_type(t) or s) + reg.__gsl_ffi_types = {} + + function reg.__gsl_reg_ffi_type(ctype, name) + local t = reg.__gsl_ffi_types + t[#t + 1] = {ctype, name} + end + + gsl_type = function(obj) + local s = reg.__gsl_type(obj) + if s == "cdata" then + for _, item in ipairs(reg.__gsl_ffi_types) do + local ctype, name = unpack(item) + if ffi.istype(ctype, obj) then return name end + end + end + return s end end @@ -79,7 +93,7 @@ tos = function (t, depth) end elseif tp == 'cdata' then local tpext = gsl_type and gsl_type(t) or tp - if tpext == 'matrix' or tpext == 'complex matrix' then + if tpext ~= 'cdata' and t.show then return (depth == 0 and t:show() or fmt("<%s: %p>", tpext, t)) end end diff --git a/matrix.lua b/matrix.lua index 4e60391b..19390dfc 100644 --- a/matrix.lua +++ b/matrix.lua @@ -960,13 +960,10 @@ matrix.tr = function(a) matrix.def = matrix_def matrix.cdef = matrix_cdef -local reg = debug.getregistry() +local register_ffi_type = debug.getregistry().__gsl_reg_ffi_type -function reg.__gsl_ffi_type(a) - if ffi.istype(gsl_complex, a) then return "complex" - elseif ffi.istype(gsl_matrix, a) then return "matrix" - elseif ffi.istype(gsl_matrix_complex, a) then return "complex matrix" end - return "cdata" -end +register_ffi_type(gsl_complex, "complex") +register_ffi_type(gsl_matrix, "matrix") +register_ffi_type(gsl_matrix_complex, "complex matrix") return matrix |