@@ -30,13 +30,13 @@ contains
3030 module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
3131 !> Input matrix A(n, n).
3232 ${rt},ドル intent(in) :: A(:, :)
33-  !> [optional] Order of the Pade approximation.
33+  !> Exponential of the input matrix E = exp(A).
34+  ${rt},ドル intent(out) :: E(:, :)
35+  !> [optional] Order of the Pade approximation.
3436 integer(ilp), optional, intent(in) :: order
3537 !> [optional] State return flag.
3638 type(linalg_state_type), optional, intent(out) :: err
37-  !> Exponential of the input matrix E = exp(A).
38-  ${rt},ドル intent(out) :: E(:, :)
39- 39+ 4040 type(linalg_state_type) :: err0
4141 integer(ilp) :: lda, n, lde, ne
4242
@@ -68,7 +68,7 @@ contains
6868 type(linalg_state_type), optional, intent(out) :: err
6969
7070 ! Internal variables.
71-  ${rt},ドル allocatable :: A2(:, :), Q(:, :), X(:, :)
71+  ${rt},ドル allocatable :: A2(:, :), Q(:, :), X(:, :), X_tmp(:, :) 
7272 real(${rk}$) :: a_norm, c
7373 integer(ilp) :: m, n, ee, k, s, order_, i, j
7474 logical(lk) :: p
@@ -105,32 +105,29 @@ contains
105105 enddo
106106
107107 ! Iteratively compute the Pade approximation.
108-  block
109-  ${rt},ドル allocatable :: X_tmp(:, :)
110-  p = .true.
111-  do k = 2, order_
112-  c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
113-  X_tmp = X
114-  #:if rt.startswith('complex')
115-  call gemm("N", "N", n, n, n, one_c${rk},ドル A2, n, X_tmp, n, zero_c${rk},ドル X, n)
116-  #:else
117-  call gemm("N", "N", n, n, n, one_${rk},ドル A2, n, X_tmp, n, zero_${rk},ドル X, n)
118-  #:endif
108+  p = .true.
109+  do k = 2, order_
110+  c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+  X_tmp = X
112+  #:if rt.startswith('complex')
113+  call gemm("N", "N", n, n, n, one_c${rk},ドル A2, n, X_tmp, n, zero_c${rk},ドル X, n)
114+  #:else
115+  call gemm("N", "N", n, n, n, one_${rk},ドル A2, n, X_tmp, n, zero_${rk},ドル X, n)
116+  #:endif
117+  do concurrent(i=1:n, j=1:n)
118+  A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
119+  enddo
120+  if (p) then
119121 do concurrent(i=1:n, j=1:n)
120-  A (i, j) = A (i, j) + c*X(i, j)  ! E  = E  + c*X
122+  Q (i, j) = Q (i, j) + c*X(i, j) ! Q  = Q  + c*X
121123 enddo
122-  if (p) then
123-  do concurrent(i=1:n, j=1:n)
124-  Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
125-  enddo
126-  else
127-  do concurrent(i=1:n, j=1:n)
128-  Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
129-  enddo
130-  endif
131-  p = .not. p
132-  enddo
133-  end block
124+  else
125+  do concurrent(i=1:n, j=1:n)
126+  Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+  enddo
128+  endif
129+  p = .not. p
130+  enddo
134131
135132 block
136133 integer(ilp) :: ipiv(n), info
@@ -139,17 +136,14 @@ contains
139136 end block
140137
141138 ! Matrix squaring.
142-  block
143-  ${rt},ドル allocatable :: E_tmp(:, :)
144-  do k = 1, s
145-  E_tmp = A
146-  #:if rt.startswith('complex')
147-  call gemm("N", "N", n, n, n, one_c${rk},ドル E_tmp, n, E_tmp, n, zero_c${rk},ドル A, n)
148-  #:else
149-  call gemm("N", "N", n, n, n, one_${rk},ドル E_tmp, n, E_tmp, n, zero_${rk},ドル A, n)
150-  #:endif
151-  enddo
152-  end block
139+  do k = 1, s
140+  X = A ! Re-use X to minimize allocations.
141+  #:if rt.startswith('complex')
142+  call gemm("N", "N", n, n, n, one_c${rk},ドル X, n, X, n, zero_c${rk},ドル A, n)
143+  #:else
144+  call gemm("N", "N", n, n, n, one_${rk},ドル X, n, X, n, zero_${rk},ドル A, n)
145+  #:endif
146+  enddo
153147 endif
154148
155149 call linalg_error_handling(err0, err)
0 commit comments