blas.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
"""
Low-level BLAS functions (:mod:`scipy.linalg.blas`)
===================================================

This module contains low-level functions from the BLAS library.

.. versionadded:: 0.12.0

.. note::

   The common ``overwrite_<>`` option in many routines allows the
   input arrays to be overwritten to avoid extra memory allocation.
   However, this requires the array to satisfy two conditions: its
   memory order and its data type must match exactly the order and
   the type expected by the routine.

   As an example, if you pass a double precision float array to any
   ``S....`` routine which expects single precision arguments, f2py
   will create an intermediate array to match the argument types and
   overwriting will be performed on that intermediate array.

   Similarly, if a C-contiguous array is passed, f2py will internally
   pass a Fortran-contiguous array instead. Please make sure that these
   details are satisfied. More information can be found in the f2py
   documentation.

.. warning::

   These functions do little to no error checking.
   It is possible to cause crashes by misusing them,
   so prefer using the higher-level routines in `scipy.linalg`.

Finding functions
-----------------

.. autosummary::
   :toctree: generated/

   get_blas_funcs
   find_best_blas_type

BLAS Level 1 functions
----------------------

.. autosummary::
   :toctree: generated/

   caxpy
   ccopy
   cdotc
   cdotu
   crotg
   cscal
   csrot
   csscal
   cswap
   dasum
   daxpy
   dcopy
   ddot
   dnrm2
   drot
   drotg
   drotm
   drotmg
   dscal
   dswap
   dzasum
   dznrm2
   icamax
   idamax
   isamax
   izamax
   sasum
   saxpy
   scasum
   scnrm2
   scopy
   sdot
   snrm2
   srot
   srotg
   srotm
   srotmg
   sscal
   sswap
   zaxpy
   zcopy
   zdotc
   zdotu
   zdrot
   zdscal
   zrotg
   zscal
   zswap

BLAS Level 2 functions
----------------------

.. autosummary::
   :toctree: generated/

   sgbmv
   sgemv
   sger
   ssbmv
   sspr
   sspr2
   ssymv
   ssyr
   ssyr2
   stbmv
   stpsv
   strmv
   strsv
   dgbmv
   dgemv
   dger
   dsbmv
   dspr
   dspr2
   dsymv
   dsyr
   dsyr2
   dtbmv
   dtpsv
   dtrmv
   dtrsv
   cgbmv
   cgemv
   cgerc
   cgeru
   chbmv
   chemv
   cher
   cher2
   chpmv
   chpr
   chpr2
   ctbmv
   ctbsv
   ctpmv
   ctpsv
   ctrmv
   ctrsv
   csyr
   zgbmv
   zgemv
   zgerc
   zgeru
   zhbmv
   zhemv
   zher
   zher2
   zhpmv
   zhpr
   zhpr2
   ztbmv
   ztbsv
   ztpmv
   ztrmv
   ztrsv
   zsyr

BLAS Level 3 functions
----------------------

.. autosummary::
   :toctree: generated/

   sgemm
   ssymm
   ssyr2k
   ssyrk
   strmm
   strsm
   dgemm
   dsymm
   dsyr2k
   dsyrk
   dtrmm
   dtrsm
   cgemm
   chemm
   cher2k
   cherk
   csymm
   csyr2k
   csyrk
   ctrmm
   ctrsm
   zgemm
   zhemm
   zher2k
   zherk
   zsymm
   zsyr2k
   zsyrk
   ztrmm
   ztrsm
"""
#
# Author: Pearu Peterson, March 2002
# refactoring by Fabian Pedregosa, March 2010
#

# Only the two lookup helpers are advertised for star-imports; the individual
# BLAS wrappers are still reachable as module attributes (see below).
__all__ = ['get_blas_funcs', 'find_best_blas_type']

import numpy as _np
import functools

from scipy.linalg import _fblas

# The C wrapper module is optional. When it is missing, _cblas is None and
# lookups fall through to the Fortran wrappers in _fblas.
try:
    from scipy.linalg import _cblas
except ImportError:
    _cblas = None

# Wrappers for an ILP64 (64-bit integer) BLAS build, if SciPy was built with
# one. HAS_ILP64 records whether they are available.
try:
    from scipy.linalg import _fblas_64
    HAS_ILP64 = True
except ImportError:
    HAS_ILP64 = False
    _fblas_64 = None

# Expose all functions (only fblas --- cblas is an implementation detail)
# Pre-binding `empty_module` guarantees that the `del` below succeeds whether
# or not the star-import defines a name spelled that way. Either way, no such
# name is left in this module's namespace afterwards.
empty_module = None
from scipy.linalg._fblas import *
del empty_module
  204. # all numeric dtypes '?bBhHiIlLqQefdgFDGO' that are safe to be converted to
  205. # single precision float : '?bBhH!!!!!!ef!!!!!!'
  206. # double precision float : '?bBhHiIlLqQefdg!!!!'
  207. # single precision complex : '?bBhH!!!!!!ef!!F!!!'
  208. # double precision complex : '?bBhHiIlLqQefdgFDG!'
  209. _type_score = {x: 1 for x in '?bBhHef'}
  210. _type_score.update({x: 2 for x in 'iIlLqQd'})
  211. # Handle float128(g) and complex256(G) separately in case non-Windows systems.
  212. # On Windows, the values will be rewritten to the same key with the same value.
  213. _type_score.update({'F': 3, 'D': 4, 'g': 2, 'G': 4})
  214. # Final mapping to the actual prefixes and dtypes
  215. _type_conv = {1: ('s', _np.dtype('float32')),
  216. 2: ('d', _np.dtype('float64')),
  217. 3: ('c', _np.dtype('complex64')),
  218. 4: ('z', _np.dtype('complex128'))}
  219. # some convenience alias for complex functions
  220. _blas_alias = {'cnrm2': 'scnrm2', 'znrm2': 'dznrm2',
  221. 'cdot': 'cdotc', 'zdot': 'zdotc',
  222. 'cger': 'cgerc', 'zger': 'zgerc',
  223. 'sdotc': 'sdot', 'sdotu': 'sdot',
  224. 'ddotc': 'ddot', 'ddotu': 'ddot'}
  225. def find_best_blas_type(arrays=(), dtype=None):
  226. """Find best-matching BLAS/LAPACK type.
  227. Arrays are used to determine the optimal prefix of BLAS routines.
  228. Parameters
  229. ----------
  230. arrays : sequence of ndarrays, optional
  231. Arrays can be given to determine optimal prefix of BLAS
  232. routines. If not given, double-precision routines will be
  233. used, otherwise the most generic type in arrays will be used.
  234. dtype : str or dtype, optional
  235. Data-type specifier. Not used if `arrays` is non-empty.
  236. Returns
  237. -------
  238. prefix : str
  239. BLAS/LAPACK prefix character.
  240. dtype : dtype
  241. Inferred Numpy data type.
  242. prefer_fortran : bool
  243. Whether to prefer Fortran order routines over C order.
  244. Examples
  245. --------
  246. >>> import numpy as np
  247. >>> import scipy.linalg.blas as bla
  248. >>> rng = np.random.default_rng()
  249. >>> a = rng.random((10,15))
  250. >>> b = np.asfortranarray(a) # Change the memory layout order
  251. >>> bla.find_best_blas_type((a,))
  252. ('d', dtype('float64'), False)
  253. >>> bla.find_best_blas_type((a*1j,))
  254. ('z', dtype('complex128'), False)
  255. >>> bla.find_best_blas_type((b,))
  256. ('d', dtype('float64'), True)
  257. """
  258. dtype = _np.dtype(dtype)
  259. max_score = _type_score.get(dtype.char, 5)
  260. prefer_fortran = False
  261. if arrays:
  262. # In most cases, single element is passed through, quicker route
  263. if len(arrays) == 1:
  264. max_score = _type_score.get(arrays[0].dtype.char, 5)
  265. prefer_fortran = arrays[0].flags['FORTRAN']
  266. else:
  267. # use the most generic type in arrays
  268. scores = [_type_score.get(x.dtype.char, 5) for x in arrays]
  269. max_score = max(scores)
  270. ind_max_score = scores.index(max_score)
  271. # safe upcasting for mix of float64 and complex64 --> prefix 'z'
  272. if max_score == 3 and (2 in scores):
  273. max_score = 4
  274. if arrays[ind_max_score].flags['FORTRAN']:
  275. # prefer Fortran for leading array with column major order
  276. prefer_fortran = True
  277. # Get the LAPACK prefix and the corresponding dtype if not fall back
  278. # to 'd' and double precision float.
  279. prefix, dtype = _type_conv.get(max_score, ('d', _np.dtype('float64')))
  280. return prefix, dtype, prefer_fortran
  281. def _get_funcs(names, arrays, dtype,
  282. lib_name, fmodule, cmodule,
  283. fmodule_name, cmodule_name, alias,
  284. ilp64=False):
  285. """
  286. Return available BLAS/LAPACK functions.
  287. Used also in lapack.py. See get_blas_funcs for docstring.
  288. """
  289. funcs = []
  290. unpack = False
  291. dtype = _np.dtype(dtype)
  292. module1 = (cmodule, cmodule_name)
  293. module2 = (fmodule, fmodule_name)
  294. if isinstance(names, str):
  295. names = (names,)
  296. unpack = True
  297. prefix, dtype, prefer_fortran = find_best_blas_type(arrays, dtype)
  298. if prefer_fortran:
  299. module1, module2 = module2, module1
  300. for name in names:
  301. func_name = prefix + name
  302. func_name = alias.get(func_name, func_name)
  303. func = getattr(module1[0], func_name, None)
  304. module_name = module1[1]
  305. if func is None:
  306. func = getattr(module2[0], func_name, None)
  307. module_name = module2[1]
  308. if func is None:
  309. raise ValueError(
  310. '%s function %s could not be found' % (lib_name, func_name))
  311. func.module_name, func.typecode = module_name, prefix
  312. func.dtype = dtype
  313. if not ilp64:
  314. func.int_dtype = _np.dtype(_np.intc)
  315. else:
  316. func.int_dtype = _np.dtype(_np.int64)
  317. func.prefix = prefix # Backward compatibility
  318. funcs.append(func)
  319. if unpack:
  320. return funcs[0]
  321. else:
  322. return funcs
  323. def _memoize_get_funcs(func):
  324. """
  325. Memoized fast path for _get_funcs instances
  326. """
  327. memo = {}
  328. func.memo = memo
  329. @functools.wraps(func)
  330. def getter(names, arrays=(), dtype=None, ilp64=False):
  331. key = (names, dtype, ilp64)
  332. for array in arrays:
  333. # cf. find_blas_funcs
  334. key += (array.dtype.char, array.flags.fortran)
  335. try:
  336. value = memo.get(key)
  337. except TypeError:
  338. # unhashable key etc.
  339. key = None
  340. value = None
  341. if value is not None:
  342. return value
  343. value = func(names, arrays, dtype, ilp64)
  344. if key is not None:
  345. memo[key] = value
  346. return value
  347. return getter
  348. @_memoize_get_funcs
  349. def get_blas_funcs(names, arrays=(), dtype=None, ilp64=False):
  350. """Return available BLAS function objects from names.
  351. Arrays are used to determine the optimal prefix of BLAS routines.
  352. Parameters
  353. ----------
  354. names : str or sequence of str
  355. Name(s) of BLAS functions without type prefix.
  356. arrays : sequence of ndarrays, optional
  357. Arrays can be given to determine optimal prefix of BLAS
  358. routines. If not given, double-precision routines will be
  359. used, otherwise the most generic type in arrays will be used.
  360. dtype : str or dtype, optional
  361. Data-type specifier. Not used if `arrays` is non-empty.
  362. ilp64 : {True, False, 'preferred'}, optional
  363. Whether to return ILP64 routine variant.
  364. Choosing 'preferred' returns ILP64 routine if available,
  365. and otherwise the 32-bit routine. Default: False
  366. Returns
  367. -------
  368. funcs : list
  369. List containing the found function(s).
  370. Notes
  371. -----
  372. This routine automatically chooses between Fortran/C
  373. interfaces. Fortran code is used whenever possible for arrays with
  374. column major order. In all other cases, C code is preferred.
  375. In BLAS, the naming convention is that all functions start with a
  376. type prefix, which depends on the type of the principal
  377. matrix. These can be one of {'s', 'd', 'c', 'z'} for the NumPy
  378. types {float32, float64, complex64, complex128} respectively.
  379. The code and the dtype are stored in attributes `typecode` and `dtype`
  380. of the returned functions.
  381. Examples
  382. --------
  383. >>> import numpy as np
  384. >>> import scipy.linalg as LA
  385. >>> rng = np.random.default_rng()
  386. >>> a = rng.random((3,2))
  387. >>> x_gemv = LA.get_blas_funcs('gemv', (a,))
  388. >>> x_gemv.typecode
  389. 'd'
  390. >>> x_gemv = LA.get_blas_funcs('gemv',(a*1j,))
  391. >>> x_gemv.typecode
  392. 'z'
  393. """
  394. if isinstance(ilp64, str):
  395. if ilp64 == 'preferred':
  396. ilp64 = HAS_ILP64
  397. else:
  398. raise ValueError("Invalid value for 'ilp64'")
  399. if not ilp64:
  400. return _get_funcs(names, arrays, dtype,
  401. "BLAS", _fblas, _cblas, "fblas", "cblas",
  402. _blas_alias, ilp64=False)
  403. else:
  404. if not HAS_ILP64:
  405. raise RuntimeError("BLAS ILP64 routine requested, but Scipy "
  406. "compiled only with 32-bit BLAS")
  407. return _get_funcs(names, arrays, dtype,
  408. "BLAS", _fblas_64, None, "fblas_64", None,
  409. _blas_alias, ilp64=True)