author | Francesco Abbate <francesco.bbt@gmail.com> | 2012年11月12日 23:51:18 +0100 |
---|---|---|
committer | Francesco Abbate <francesco.bbt@gmail.com> | 2012年11月12日 23:51:18 +0100 |
commit | 70ee4afa111c6f7456f102b75b314f0f83fd010e (patch) | |
tree | 8dea547262fd7e264052ca858a00ac6fadf334c4 | |
parent | ca4b11a05451867f6ec2df647c61044df3b69b50 (diff) | |
download | gsl-shell-70ee4afa111c6f7456f102b75b314f0f83fd010e.tar.gz |
-rw-r--r-- | doc/user-manual/matrices.rst | 5 | ||||
-rw-r--r-- | help/matrix.lua | 8 | ||||
-rw-r--r-- | matrix.lua | 24 |
diff --git a/doc/user-manual/matrices.rst b/doc/user-manual/matrices.rst index 3377a6fc..23f5d05d 100644 --- a/doc/user-manual/matrices.rst +++ b/doc/user-manual/matrices.rst @@ -215,10 +215,9 @@ All the functions described in this section have an equivalent function for comp Return the hermitian conjugate of the matrix. -.. function:: diag(v) +.. function:: diag(t) - Given a column vector ``v`` of length ``n`` returns a diagonal - matrix whose diagonal elements are equal to the elements of ``v``. + Given a table ``t`` of length ``n`` returns a diagonal matrix whose diagonal elements are equal to the elements of ``t``. .. function:: unit(n) diff --git a/help/matrix.lua b/help/matrix.lua index ca1361f2..168f7d16 100644 --- a/help/matrix.lua +++ b/help/matrix.lua @@ -60,16 +60,16 @@ matrix.hc(m) ]], [matrix.diag] = [[ -matrix.diag(v) +matrix.diag(t) - Given a column vector "v" of length "n", returns a diagonal - matrix whose diagonal elements are equal to the elements of "v". + Given a table "t" of length "n", returns a diagonal matrix whose + diagonal elements are equal to the elements of "t". ]], [matrix.unit] = [[ matrix.unit(n) - Return the unit matrix of dimension nxn. + Return the unit matrix of dimension n by n. ]], [matrix.set] = [[ diff --git a/matrix.lua b/matrix.lua index cbf24590..4e60391b 100644 --- a/matrix.lua +++ b/matrix.lua @@ -580,12 +580,9 @@ local complex_mt = { ffi.metatype(gsl_complex, complex_mt) local function matrix_new_unit(n) - local m = matrix_alloc(n, n) - for i=0, n-1 do - for j=0, n-1 do - m.data[i*n+j] = (i == j and 1 or 0) - end - end + local m = matrix.alloc(n, n) + for k = 0, n*n - 1 do m.data[k] = 0 end + for k = 0, n-1 do m.data[k*(n+1)] = 1 end return m end @@ -940,16 +937,11 @@ function matrix.svd(a) return u, s, v end -matrix.diag = function(d) - local n = #d - local m = d.alloc(n, n) - local mset, dget = m.set, d.get - for i=1, n do - for j= 1, n do - local x = (i ~= j and 0 or dget(d, i, 1)) - mset(m, i, j, x) - end - end +matrix.diag = function(t) + local n = #t + local m = matrix.alloc(n, n) + for k = 0, n*n - 1 do m.data[k] = 0 end + for k = 0, n-1 do m.data[k*(n+1)] = t[k+1] end return m end |