Skip to content

Commit a34ecce

Browse files
committed
ENH: linalg: speed up _sqrtm_triu by moving tight loop to Cython
Move the bottleneck loop in _sqrtm_triu (the within-block interactions) to Cython to speed it up. The run time of the second loop (the between-block interactions) is dominated by the LAPACK D/ZTRSYL calls, so Cythonizing it would not give a speedup.
1 parent 70428a9 commit a34ecce

File tree

3 files changed

+53
-18
lines changed

3 files changed

+53
-18
lines changed

scipy/linalg/_matfuncs_sqrtm.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class SqrtmError(np.linalg.LinAlgError):
2121
pass
2222

2323

24+
from ._matfuncs_sqrtm_triu import within_block_loop
25+
26+
2427
def _sqrtm_triu(T, blocksize=64):
2528
"""
2629
Matrix square root of an upper triangular matrix.
@@ -49,8 +52,15 @@ def _sqrtm_triu(T, blocksize=64):
4952
"""
5053
T_diag = np.diag(T)
5154
keep_it_real = np.isrealobj(T) and np.min(T_diag) >= 0
55+
56+
# Cast to complex as necessary + ensure double precision
5257
if not keep_it_real:
53-
T_diag = T_diag.astype(complex)
58+
T = np.asarray(T, dtype=np.complex128, order="C")
59+
T_diag = np.asarray(T_diag, dtype=np.complex128)
60+
else:
61+
T = np.asarray(T, dtype=np.float64, order="C")
62+
T_diag = np.asarray(T_diag, dtype=np.float64)
63+
5464
R = np.diag(np.sqrt(T_diag))
5565

5666
# Compute the number of blocks to use; use at least one block.
@@ -73,23 +83,10 @@ def _sqrtm_triu(T, blocksize=64):
7383
start_stop_pairs.append((start, start + size))
7484
start += size
7585

76-
# Within-block interactions.
77-
for start, stop in start_stop_pairs:
78-
for j in range(start, stop):
79-
for i in range(j-1, start-1, -1):
80-
s = 0
81-
if j - i > 1:
82-
s = R[i, i+1:j].dot(R[i+1:j, j])
83-
denom = R[i, i] + R[j, j]
84-
num = T[i, j] - s
85-
if denom != 0:
86-
R[i, j] = (T[i, j] - s) / denom
87-
elif denom == 0 and num == 0:
88-
R[i, j] = 0
89-
else:
90-
raise SqrtmError('failed to find the matrix square root')
91-
92-
# Between-block interactions.
86+
# Within-block interactions (Cythonized)
87+
within_block_loop(R, T, start_stop_pairs, nblocks)
88+
89+
# Between-block interactions (Cython would give no significant speedup)
9390
for j in range(nblocks):
9491
jstart, jstop = start_stop_pairs[j]
9592
for i in range(j-1, -1, -1):
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# cython: boundscheck=False, wraparound=False, cdivision=True
2+
import numpy as np
3+
from ._matfuncs_sqrtm import SqrtmError
4+
5+
from numpy cimport complex128_t, float64_t, intp_t
6+
7+
8+
cdef fused floating:
9+
float64_t
10+
complex128_t
11+
12+
13+
def within_block_loop(floating[:,::1] R, floating[:,::1] T, start_stop_pairs, intp_t nblocks):
14+
cdef intp_t start, stop, i, j, k
15+
cdef floating s, denom, num
16+
17+
for start, stop in start_stop_pairs:
18+
for j in range(start, stop):
19+
for i in range(j-1, start-1, -1):
20+
s = 0
21+
if j - i > 1:
22+
# s = R[i,i+1:j] @ R[i+1:j,j]
23+
for k in range(i + 1, j):
24+
s += R[i,k] * R[k,j]
25+
26+
denom = R[i, i] + R[j, j]
27+
num = T[i, j] - s
28+
if denom != 0:
29+
R[i, j] = (T[i, j] - s) / denom
30+
elif denom == 0 and num == 0:
31+
R[i, j] = 0
32+
else:
33+
raise SqrtmError('failed to find the matrix square root')

scipy/linalg/setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def configuration(parent_package='', top_path=None):
105105
sources=[('_solve_toeplitz.c')],
106106
include_dirs=[get_numpy_include_dirs()])
107107

108+
# _matfuncs_sqrtm_triu:
109+
config.add_extension('_matfuncs_sqrtm_triu',
110+
sources=[('_matfuncs_sqrtm_triu.c')],
111+
include_dirs=[get_numpy_include_dirs()])
112+
108113
config.add_data_dir('tests')
109114

110115
# Cython BLAS/LAPACK

0 commit comments

Comments
 (0)