|  | 
|  | 1 | +#:include "common.fypp" | 
|  | 2 | +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX, REAL_INIT)) | 
|  | 3 | +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX, CMPLX_INIT)) | 
|  | 4 | +#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES | 
|  | 5 | +submodule (stdlib_linalg) stdlib_linalg_matrix_functions | 
|  | 6 | + use stdlib_constants | 
|  | 7 | + use stdlib_linalg_constants | 
|  | 8 | + use stdlib_linalg_blas, only: gemm | 
|  | 9 | + use stdlib_linalg_lapack, only: gesv, lacpy | 
|  | 10 | + use stdlib_linalg_lapack_aux, only: handle_gesv_info | 
|  | 11 | + use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, & | 
|  | 12 | + LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR | 
|  | 13 | + implicit none(type, external) | 
|  | 14 | + | 
|  | 15 | + character(len=*), parameter :: this = "matrix_exponential" | 
|  | 16 | + | 
|  | 17 | +contains | 
|  | 18 | + | 
|  | 19 | + #:for k,t,s, i in RC_KINDS_TYPES  | 
|  | 20 | + module function stdlib_linalg_${i}$_expm_fun(A, order) result(E) | 
|  | 21 | + !> Input matrix A(n, n). | 
|  | 22 | + ${t},ドル intent(in) :: A(:, :) | 
|  | 23 | + !> [optional] Order of the Pade approximation. | 
|  | 24 | + integer(ilp), optional, intent(in) :: order | 
|  | 25 | + !> Exponential of the input matrix E = exp(A). | 
|  | 26 | + ${t},ドル allocatable :: E(:, :) | 
|  | 27 | + | 
|  | 28 | + E = A | 
|  | 29 | + call stdlib_linalg_${i}$_expm_inplace(E, order) | 
|  | 30 | + end function stdlib_linalg_${i}$_expm_fun | 
|  | 31 | + | 
|  | 32 | + module subroutine stdlib_linalg_${i}$_expm(A, E, order, err) | 
|  | 33 | + !> Input matrix A(n, n). | 
|  | 34 | + ${t},ドル intent(in) :: A(:, :) | 
|  | 35 | + !> Exponential of the input matrix E = exp(A). | 
|  | 36 | + ${t},ドル intent(out) :: E(:, :) | 
|  | 37 | + !> [optional] Order of the Pade approximation. | 
|  | 38 | + integer(ilp), optional, intent(in) :: order | 
|  | 39 | + !> [optional] State return flag. | 
|  | 40 | + type(linalg_state_type), optional, intent(out) :: err | 
|  | 41 | + | 
|  | 42 | + type(linalg_state_type) :: err0 | 
|  | 43 | + integer(ilp) :: lda, n, lde, ne | 
|  | 44 | + | 
|  | 45 | + ! Check E sizes | 
|  | 46 | + lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp) | 
|  | 47 | + lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp) | 
|  | 48 | + | 
|  | 49 | + if (lda<1 .or. n<1 .or. lda/=n .or. lde/=n .or. ne/=n) then  | 
|  | 50 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR, & | 
|  | 51 | + 'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', & | 
|  | 52 | + ' E must be square (lde=', lde, ', ne=', ne, ')') | 
|  | 53 | + else | 
|  | 54 | + call lacpy("n", n, n, A, n, E, n) ! E = A | 
|  | 55 | + call stdlib_linalg_${i}$_expm_inplace(E, order, err0) | 
|  | 56 | + endif | 
|  | 57 | + | 
|  | 58 | + ! Process output and return | 
|  | 59 | + call linalg_error_handling(err0,err) | 
|  | 60 | + | 
|  | 61 | + return | 
|  | 62 | + end subroutine stdlib_linalg_${i}$_expm | 
|  | 63 | + | 
|  | 64 | + module subroutine stdlib_linalg_${i}$_expm_inplace(A, order, err) | 
|  | 65 | + !> Input matrix A(n, n) / Output matrix exponential. | 
|  | 66 | + ${t},ドル intent(inout) :: A(:, :) | 
|  | 67 | + !> [optional] Order of the Pade approximation. | 
|  | 68 | + integer(ilp), optional, intent(in) :: order | 
|  | 69 | + !> [optional] State return flag. | 
|  | 70 | + type(linalg_state_type), optional, intent(out) :: err | 
|  | 71 | + | 
|  | 72 | + ! Internal variables. | 
|  | 73 | + ${t}$ :: A2(size(A, 1), size(A, 2)), Q(size(A, 1), size(A, 2)) | 
|  | 74 | + ${t}$ :: X(size(A, 1), size(A, 2)), X_tmp(size(A, 1), size(A, 2)) | 
|  | 75 | + real(${k}$) :: a_norm, c | 
|  | 76 | + integer(ilp) :: m, n, ee, k, s, order_, i, j | 
|  | 77 | + logical(lk) :: p | 
|  | 78 | + type(linalg_state_type) :: err0 | 
|  | 79 | + | 
|  | 80 | + ! Deal with optional args. | 
|  | 81 | + order_ = 10 ; if (present(order)) order_ = order | 
|  | 82 | + | 
|  | 83 | + ! Problem's dimension. | 
|  | 84 | + m = size(A, dim=1, kind=ilp) ; n = size(A, dim=2, kind=ilp) | 
|  | 85 | + | 
|  | 86 | + if (m /= n) then | 
|  | 87 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n]) | 
|  | 88 | + else if (order_ < 0) then | 
|  | 89 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation & | 
|  | 90 | + needs to be positive, order=', order_) | 
|  | 91 | + else | 
|  | 92 | + ! Compute the L-infinity norm. | 
|  | 93 | + a_norm = mnorm(A, "inf") | 
|  | 94 | + | 
|  | 95 | + ! Determine scaling factor for the matrix. | 
|  | 96 | + ee = int(log(a_norm) / log2_${k},ドル kind=ilp) + 1 | 
|  | 97 | + s = max(0, ee+1) | 
|  | 98 | + | 
|  | 99 | + ! Scale the input matrix & initialize polynomial. | 
|  | 100 | + A2 = A/2.0_${k}$**s | 
|  | 101 | + call lacpy("n", n, n, A2, n, X, n) ! X = A2 | 
|  | 102 | + | 
|  | 103 | + ! First step of the Pade approximation. | 
|  | 104 | + c = 0.5_${k}$ | 
|  | 105 | + do concurrent(i=1:n, j=1:n) | 
|  | 106 | + A(i, j) = merge(1.0_${k}$ + c*A2(i, j), c*A2(i, j), i == j) | 
|  | 107 | + Q(i, j) = merge(1.0_${k}$ - c*A2(i, j), -c*A2(i, j), i == j) | 
|  | 108 | + enddo | 
|  | 109 | + | 
|  | 110 | + ! Iteratively compute the Pade approximation. | 
|  | 111 | + p = .true. | 
|  | 112 | + do k = 2, order_ | 
|  | 113 | + c = c * (order_ - k + 1) / (k * (2*order_ - k + 1)) | 
|  | 114 | + call lacpy("n", n, n, X, n, X_tmp, n) ! X_tmp = X | 
|  | 115 | + call gemm("N", "N", n, n, n, one_${s},ドル A2, n, X_tmp, n, zero_${s},ドル X, n) | 
|  | 116 | + do concurrent(i=1:n, j=1:n) | 
|  | 117 | + A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X | 
|  | 118 | + Q(i, j) = merge(Q(i, j) + c*X(i, j), Q(i, j) - c*X(i, j), p) | 
|  | 119 | + enddo | 
|  | 120 | + p = .not. p | 
|  | 121 | + enddo | 
|  | 122 | + | 
|  | 123 | + block | 
|  | 124 | + integer(ilp) :: ipiv(n), info | 
|  | 125 | + call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E | 
|  | 126 | + call handle_gesv_info(this, info, n, n, n, err0) | 
|  | 127 | + end block | 
|  | 128 | + | 
|  | 129 | + ! Matrix squaring. | 
|  | 130 | + do k = 1, s | 
|  | 131 | + call lacpy("n", n, n, A, n, X, n) ! X = A | 
|  | 132 | + call gemm("N", "N", n, n, n, one_${s},ドル X, n, X, n, zero_${s},ドル A, n) | 
|  | 133 | + enddo | 
|  | 134 | + endif | 
|  | 135 | + | 
|  | 136 | + call linalg_error_handling(err0, err) | 
|  | 137 | + | 
|  | 138 | + return | 
|  | 139 | + end subroutine stdlib_linalg_${i}$_expm_inplace | 
|  | 140 | + #:endfor | 
|  | 141 | + | 
|  | 142 | +end submodule stdlib_linalg_matrix_functions | 
0 commit comments