fnodes.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. """
  2. AST nodes specific to Fortran.
The functions defined in this module allow the user to express functions such as ``dsign``
as a SymPy function for symbolic manipulation.
  5. """
  6. from sympy.codegen.ast import (
  7. Attribute, CodeBlock, FunctionCall, Node, none, String,
  8. Token, _mk_Tuple, Variable
  9. )
  10. from sympy.core.basic import Basic
  11. from sympy.core.containers import Tuple
  12. from sympy.core.expr import Expr
  13. from sympy.core.function import Function
  14. from sympy.core.numbers import Float, Integer
  15. from sympy.core.symbol import Str
  16. from sympy.core.sympify import sympify
  17. from sympy.logic import true, false
  18. from sympy.utilities.iterables import iterable
  19. pure = Attribute('pure')
  20. elemental = Attribute('elemental') # (all elemental procedures are also pure)
  21. intent_in = Attribute('intent_in')
  22. intent_out = Attribute('intent_out')
  23. intent_inout = Attribute('intent_inout')
  24. allocatable = Attribute('allocatable')
  25. class Program(Token):
  26. """ Represents a 'program' block in Fortran.
  27. Examples
  28. ========
  29. >>> from sympy.codegen.ast import Print
  30. >>> from sympy.codegen.fnodes import Program
  31. >>> prog = Program('myprogram', [Print([42])])
  32. >>> from sympy import fcode
  33. >>> print(fcode(prog, source_format='free'))
  34. program myprogram
  35. print *, 42
  36. end program
  37. """
  38. __slots__ = _fields = ('name', 'body')
  39. _construct_name = String
  40. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  41. class use_rename(Token):
  42. """ Represents a renaming in a use statement in Fortran.
  43. Examples
  44. ========
  45. >>> from sympy.codegen.fnodes import use_rename, use
  46. >>> from sympy import fcode
  47. >>> ren = use_rename("thingy", "convolution2d")
  48. >>> print(fcode(ren, source_format='free'))
  49. thingy => convolution2d
  50. >>> full = use('signallib', only=['snr', ren])
  51. >>> print(fcode(full, source_format='free'))
  52. use signallib, only: snr, thingy => convolution2d
  53. """
  54. __slots__ = _fields = ('local', 'original')
  55. _construct_local = String
  56. _construct_original = String
  57. def _name(arg):
  58. if hasattr(arg, 'name'):
  59. return arg.name
  60. else:
  61. return String(arg)
  62. class use(Token):
  63. """ Represents a use statement in Fortran.
  64. Examples
  65. ========
  66. >>> from sympy.codegen.fnodes import use
  67. >>> from sympy import fcode
  68. >>> fcode(use('signallib'), source_format='free')
  69. 'use signallib'
  70. >>> fcode(use('signallib', [('metric', 'snr')]), source_format='free')
  71. 'use signallib, metric => snr'
  72. >>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free')
  73. 'use signallib, only: snr, convolution2d'
  74. """
  75. __slots__ = _fields = ('namespace', 'rename', 'only')
  76. defaults = {'rename': none, 'only': none}
  77. _construct_namespace = staticmethod(_name)
  78. _construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args]))
  79. _construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args]))
  80. class Module(Token):
  81. """ Represents a module in Fortran.
  82. Examples
  83. ========
  84. >>> from sympy.codegen.fnodes import Module
  85. >>> from sympy import fcode
  86. >>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free'))
  87. module signallib
  88. implicit none
  89. <BLANKLINE>
  90. contains
  91. <BLANKLINE>
  92. <BLANKLINE>
  93. end module
  94. """
  95. __slots__ = _fields = ('name', 'declarations', 'definitions')
  96. defaults = {'declarations': Tuple()}
  97. _construct_name = String
  98. @classmethod
  99. def _construct_declarations(cls, args):
  100. args = [Str(arg) if isinstance(arg, str) else arg for arg in args]
  101. return CodeBlock(*args)
  102. _construct_definitions = staticmethod(lambda arg: CodeBlock(*arg))
  103. class Subroutine(Node):
  104. """ Represents a subroutine in Fortran.
  105. Examples
  106. ========
  107. >>> from sympy import fcode, symbols
  108. >>> from sympy.codegen.ast import Print
  109. >>> from sympy.codegen.fnodes import Subroutine
  110. >>> x, y = symbols('x y', real=True)
  111. >>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])])
  112. >>> print(fcode(sub, source_format='free', standard=2003))
  113. subroutine mysub(x, y)
  114. real*8 :: x
  115. real*8 :: y
  116. print *, x**2 + y**2, x*y
  117. end subroutine
  118. """
  119. __slots__ = ('name', 'parameters', 'body')
  120. _fields = __slots__ + Node._fields
  121. _construct_name = String
  122. _construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params)))
  123. @classmethod
  124. def _construct_body(cls, itr):
  125. if isinstance(itr, CodeBlock):
  126. return itr
  127. else:
  128. return CodeBlock(*itr)
  129. class SubroutineCall(Token):
  130. """ Represents a call to a subroutine in Fortran.
  131. Examples
  132. ========
  133. >>> from sympy.codegen.fnodes import SubroutineCall
  134. >>> from sympy import fcode
  135. >>> fcode(SubroutineCall('mysub', 'x y'.split()))
  136. ' call mysub(x, y)'
  137. """
  138. __slots__ = _fields = ('name', 'subroutine_args')
  139. _construct_name = staticmethod(_name)
  140. _construct_subroutine_args = staticmethod(_mk_Tuple)
  141. class Do(Token):
  142. """ Represents a Do loop in in Fortran.
  143. Examples
  144. ========
  145. >>> from sympy import fcode, symbols
  146. >>> from sympy.codegen.ast import aug_assign, Print
  147. >>> from sympy.codegen.fnodes import Do
  148. >>> i, n = symbols('i n', integer=True)
  149. >>> r = symbols('r', real=True)
  150. >>> body = [aug_assign(r, '+', 1/i), Print([i, r])]
  151. >>> do1 = Do(body, i, 1, n)
  152. >>> print(fcode(do1, source_format='free'))
  153. do i = 1, n
  154. r = r + 1d0/i
  155. print *, i, r
  156. end do
  157. >>> do2 = Do(body, i, 1, n, 2)
  158. >>> print(fcode(do2, source_format='free'))
  159. do i = 1, n, 2
  160. r = r + 1d0/i
  161. print *, i, r
  162. end do
  163. """
  164. __slots__ = _fields = ('body', 'counter', 'first', 'last', 'step', 'concurrent')
  165. defaults = {'step': Integer(1), 'concurrent': false}
  166. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  167. _construct_counter = staticmethod(sympify)
  168. _construct_first = staticmethod(sympify)
  169. _construct_last = staticmethod(sympify)
  170. _construct_step = staticmethod(sympify)
  171. _construct_concurrent = staticmethod(lambda arg: true if arg else false)
  172. class ArrayConstructor(Token):
  173. """ Represents an array constructor.
  174. Examples
  175. ========
  176. >>> from sympy import fcode
  177. >>> from sympy.codegen.fnodes import ArrayConstructor
  178. >>> ac = ArrayConstructor([1, 2, 3])
  179. >>> fcode(ac, standard=95, source_format='free')
  180. '(/1, 2, 3/)'
  181. >>> fcode(ac, standard=2003, source_format='free')
  182. '[1, 2, 3]'
  183. """
  184. __slots__ = _fields = ('elements',)
  185. _construct_elements = staticmethod(_mk_Tuple)
  186. class ImpliedDoLoop(Token):
  187. """ Represents an implied do loop in Fortran.
  188. Examples
  189. ========
  190. >>> from sympy import Symbol, fcode
  191. >>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor
  192. >>> i = Symbol('i', integer=True)
  193. >>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27
  194. >>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28
  195. >>> fcode(ac, standard=2003, source_format='free')
  196. '[-28, (i**3, i = -3, 3, 2), 28]'
  197. """
  198. __slots__ = _fields = ('expr', 'counter', 'first', 'last', 'step')
  199. defaults = {'step': Integer(1)}
  200. _construct_expr = staticmethod(sympify)
  201. _construct_counter = staticmethod(sympify)
  202. _construct_first = staticmethod(sympify)
  203. _construct_last = staticmethod(sympify)
  204. _construct_step = staticmethod(sympify)
  205. class Extent(Basic):
  206. """ Represents a dimension extent.
  207. Examples
  208. ========
  209. >>> from sympy.codegen.fnodes import Extent
  210. >>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3
  211. >>> from sympy import fcode
  212. >>> fcode(e, source_format='free')
  213. '-3:3'
  214. >>> from sympy.codegen.ast import Variable, real
  215. >>> from sympy.codegen.fnodes import dimension, intent_out
  216. >>> dim = dimension(e, e)
  217. >>> arr = Variable('x', real, attrs=[dim, intent_out])
  218. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  219. 'real*8, dimension(-3:3, -3:3), intent(out) :: x'
  220. """
  221. def __new__(cls, *args):
  222. if len(args) == 2:
  223. low, high = args
  224. return Basic.__new__(cls, sympify(low), sympify(high))
  225. elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)):
  226. return Basic.__new__(cls) # assumed shape
  227. else:
  228. raise ValueError("Expected 0 or 2 args (or one argument == None or ':')")
  229. def _sympystr(self, printer):
  230. if len(self.args) == 0:
  231. return ':'
  232. return ":".join(str(arg) for arg in self.args)
  233. assumed_extent = Extent() # or Extent(':'), Extent(None)
  234. def dimension(*args):
  235. """ Creates a 'dimension' Attribute with (up to 7) extents.
  236. Examples
  237. ========
  238. >>> from sympy import fcode
  239. >>> from sympy.codegen.fnodes import dimension, intent_in
  240. >>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns
  241. >>> from sympy.codegen.ast import Variable, integer
  242. >>> arr = Variable('a', integer, attrs=[dim, intent_in])
  243. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  244. 'integer*4, dimension(2, :), intent(in) :: a'
  245. """
  246. if len(args) > 7:
  247. raise ValueError("Fortran only supports up to 7 dimensional arrays")
  248. parameters = []
  249. for arg in args:
  250. if isinstance(arg, Extent):
  251. parameters.append(arg)
  252. elif isinstance(arg, str):
  253. if arg == ':':
  254. parameters.append(Extent())
  255. else:
  256. parameters.append(String(arg))
  257. elif iterable(arg):
  258. parameters.append(Extent(*arg))
  259. else:
  260. parameters.append(sympify(arg))
  261. if len(args) == 0:
  262. raise ValueError("Need at least one dimension")
  263. return Attribute('dimension', parameters)
  264. assumed_size = dimension('*')
  265. def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None):
  266. """ Convenience function for creating a Variable instance for a Fortran array.
  267. Parameters
  268. ==========
  269. symbol : symbol
  270. dim : Attribute or iterable
  271. If dim is an ``Attribute`` it need to have the name 'dimension'. If it is
  272. not an ``Attribute``, then it is passed to :func:`dimension` as ``*dim``
  273. intent : str
  274. One of: 'in', 'out', 'inout' or None
  275. \\*\\*kwargs:
  276. Keyword arguments for ``Variable`` ('type' & 'value')
  277. Examples
  278. ========
  279. >>> from sympy import fcode
  280. >>> from sympy.codegen.ast import integer, real
  281. >>> from sympy.codegen.fnodes import array
  282. >>> arr = array('a', '*', 'in', type=integer)
  283. >>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003))
  284. integer*4, dimension(*), intent(in) :: a
  285. >>> x = array('x', [3, ':', ':'], intent='out', type=real)
  286. >>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003))
  287. real*8, dimension(3, :, :), intent(out) :: x = 1
  288. """
  289. if isinstance(dim, Attribute):
  290. if str(dim.name) != 'dimension':
  291. raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim))
  292. else:
  293. dim = dimension(*dim)
  294. attrs = list(attrs) + [dim]
  295. if intent is not None:
  296. if intent not in (intent_in, intent_out, intent_inout):
  297. intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent]
  298. attrs.append(intent)
  299. if type is None:
  300. return Variable.deduced(symbol, value=value, attrs=attrs)
  301. else:
  302. return Variable(symbol, type, value=value, attrs=attrs)
  303. def _printable(arg):
  304. return String(arg) if isinstance(arg, str) else sympify(arg)
  305. def allocated(array):
  306. """ Creates an AST node for a function call to Fortran's "allocated(...)"
  307. Examples
  308. ========
  309. >>> from sympy import fcode
  310. >>> from sympy.codegen.fnodes import allocated
  311. >>> alloc = allocated('x')
  312. >>> fcode(alloc, source_format='free')
  313. 'allocated(x)'
  314. """
  315. return FunctionCall('allocated', [_printable(array)])
  316. def lbound(array, dim=None, kind=None):
  317. """ Creates an AST node for a function call to Fortran's "lbound(...)"
  318. Parameters
  319. ==========
  320. array : Symbol or String
  321. dim : expr
  322. kind : expr
  323. Examples
  324. ========
  325. >>> from sympy import fcode
  326. >>> from sympy.codegen.fnodes import lbound
  327. >>> lb = lbound('arr', dim=2)
  328. >>> fcode(lb, source_format='free')
  329. 'lbound(arr, 2)'
  330. """
  331. return FunctionCall(
  332. 'lbound',
  333. [_printable(array)] +
  334. ([_printable(dim)] if dim else []) +
  335. ([_printable(kind)] if kind else [])
  336. )
  337. def ubound(array, dim=None, kind=None):
  338. return FunctionCall(
  339. 'ubound',
  340. [_printable(array)] +
  341. ([_printable(dim)] if dim else []) +
  342. ([_printable(kind)] if kind else [])
  343. )
  344. def shape(source, kind=None):
  345. """ Creates an AST node for a function call to Fortran's "shape(...)"
  346. Parameters
  347. ==========
  348. source : Symbol or String
  349. kind : expr
  350. Examples
  351. ========
  352. >>> from sympy import fcode
  353. >>> from sympy.codegen.fnodes import shape
  354. >>> shp = shape('x')
  355. >>> fcode(shp, source_format='free')
  356. 'shape(x)'
  357. """
  358. return FunctionCall(
  359. 'shape',
  360. [_printable(source)] +
  361. ([_printable(kind)] if kind else [])
  362. )
  363. def size(array, dim=None, kind=None):
  364. """ Creates an AST node for a function call to Fortran's "size(...)"
  365. Examples
  366. ========
  367. >>> from sympy import fcode, Symbol
  368. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  369. >>> from sympy.codegen.fnodes import array, sum_, size
  370. >>> a = Symbol('a', real=True)
  371. >>> body = [Return((sum_(a**2)/size(a))**.5)]
  372. >>> arr = array(a, dim=[':'], intent='in')
  373. >>> fd = FunctionDefinition(real, 'rms', [arr], body)
  374. >>> print(fcode(fd, source_format='free', standard=2003))
  375. real*8 function rms(a)
  376. real*8, dimension(:), intent(in) :: a
  377. rms = sqrt(sum(a**2)*1d0/size(a))
  378. end function
  379. """
  380. return FunctionCall(
  381. 'size',
  382. [_printable(array)] +
  383. ([_printable(dim)] if dim else []) +
  384. ([_printable(kind)] if kind else [])
  385. )
  386. def reshape(source, shape, pad=None, order=None):
  387. """ Creates an AST node for a function call to Fortran's "reshape(...)"
  388. Parameters
  389. ==========
  390. source : Symbol or String
  391. shape : ArrayExpr
  392. """
  393. return FunctionCall(
  394. 'reshape',
  395. [_printable(source), _printable(shape)] +
  396. ([_printable(pad)] if pad else []) +
  397. ([_printable(order)] if pad else [])
  398. )
  399. def bind_C(name=None):
  400. """ Creates an Attribute ``bind_C`` with a name.
  401. Parameters
  402. ==========
  403. name : str
  404. Examples
  405. ========
  406. >>> from sympy import fcode, Symbol
  407. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  408. >>> from sympy.codegen.fnodes import array, sum_, bind_C
  409. >>> a = Symbol('a', real=True)
  410. >>> s = Symbol('s', integer=True)
  411. >>> arr = array(a, dim=[s], intent='in')
  412. >>> body = [Return((sum_(a**2)/s)**.5)]
  413. >>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
  414. >>> print(fcode(fd, source_format='free', standard=2003))
  415. real*8 function rms(a, s) bind(C, name="rms")
  416. real*8, dimension(s), intent(in) :: a
  417. integer*4 :: s
  418. rms = sqrt(sum(a**2)/s)
  419. end function
  420. """
  421. return Attribute('bind_C', [String(name)] if name else [])
  422. class GoTo(Token):
  423. """ Represents a goto statement in Fortran
  424. Examples
  425. ========
  426. >>> from sympy.codegen.fnodes import GoTo
  427. >>> go = GoTo([10, 20, 30], 'i')
  428. >>> from sympy import fcode
  429. >>> fcode(go, source_format='free')
  430. 'go to (10, 20, 30), i'
  431. """
  432. __slots__ = _fields = ('labels', 'expr')
  433. defaults = {'expr': none}
  434. _construct_labels = staticmethod(_mk_Tuple)
  435. _construct_expr = staticmethod(sympify)
  436. class FortranReturn(Token):
  437. """ AST node explicitly mapped to a fortran "return".
  438. Explanation
  439. ===========
  440. Because a return statement in fortran is different from C, and
  441. in order to aid reuse of our codegen ASTs the ordinary
  442. ``.codegen.ast.Return`` is interpreted as assignment to
  443. the result variable of the function. If one for some reason needs
  444. to generate a fortran RETURN statement, this node should be used.
  445. Examples
  446. ========
  447. >>> from sympy.codegen.fnodes import FortranReturn
  448. >>> from sympy import fcode
  449. >>> fcode(FortranReturn('x'))
  450. ' return x'
  451. """
  452. __slots__ = _fields = ('return_value',)
  453. defaults = {'return_value': none}
  454. _construct_return_value = staticmethod(sympify)
  455. class FFunction(Function):
  456. _required_standard = 77
  457. def _fcode(self, printer):
  458. name = self.__class__.__name__
  459. if printer._settings['standard'] < self._required_standard:
  460. raise NotImplementedError("%s requires Fortran %d or newer" %
  461. (name, self._required_standard))
  462. return '{}({})'.format(name, ', '.join(map(printer._print, self.args)))
  463. class F95Function(FFunction):
  464. _required_standard = 95
  465. class isign(FFunction):
  466. """ Fortran sign intrinsic for integer arguments. """
  467. nargs = 2
  468. class dsign(FFunction):
  469. """ Fortran sign intrinsic for double precision arguments. """
  470. nargs = 2
  471. class cmplx(FFunction):
  472. """ Fortran complex conversion function. """
  473. nargs = 2 # may be extended to (2, 3) at a later point
  474. class kind(FFunction):
  475. """ Fortran kind function. """
  476. nargs = 1
  477. class merge(F95Function):
  478. """ Fortran merge function """
  479. nargs = 3
  480. class _literal(Float):
  481. _token = None # type: str
  482. _decimals = None # type: int
  483. def _fcode(self, printer, *args, **kwargs):
  484. mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e')
  485. mantissa = mantissa.strip('0').rstrip('.')
  486. ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
  487. ex_sgn = '' if ex_sgn == '+' else ex_sgn
  488. return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')
  489. class literal_sp(_literal):
  490. """ Fortran single precision real literal """
  491. _token = 'e'
  492. _decimals = 9
  493. class literal_dp(_literal):
  494. """ Fortran double precision real literal """
  495. _token = 'd'
  496. _decimals = 17
  497. class sum_(Token, Expr):
  498. __slots__ = _fields = ('array', 'dim', 'mask')
  499. defaults = {'dim': none, 'mask': none}
  500. _construct_array = staticmethod(sympify)
  501. _construct_dim = staticmethod(sympify)
  502. class product_(Token, Expr):
  503. __slots__ = _fields = ('array', 'dim', 'mask')
  504. defaults = {'dim': none, 'mask': none}
  505. _construct_array = staticmethod(sympify)
  506. _construct_dim = staticmethod(sympify)