# op_properties.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
  6. import torch
# Pointwise operators can go through a faster pathway.
  8. tensor_magic_methods = [
  9. 'add',
  10. ''
  11. ]
  12. pointwise_magic_methods_with_reverse = (
  13. 'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
  14. 'pow', 'lshift', 'rshift', 'and', 'or', 'xor'
  15. )
  16. pointwise_magic_methods = (
  17. *(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)),
  18. 'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos',
  19. 'abs', 'invert',
  20. 'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
  21. 'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
  22. 'ior', 'ixor',
  23. 'int', 'long', 'float', 'complex',
  24. )
  25. pointwise_methods = (
  26. *(f'__{m}__' for m in pointwise_magic_methods),
  27. )
  28. pointwise = (
  29. *(getattr(torch.Tensor, m) for m in pointwise_methods),
  30. torch.nn.functional.dropout,
  31. torch.where,
  32. torch.Tensor.abs,
  33. torch.abs,
  34. torch.Tensor.acos,
  35. torch.acos,
  36. torch.Tensor.acosh,
  37. torch.acosh,
  38. torch.Tensor.add,
  39. torch.add,
  40. torch.Tensor.addcdiv,
  41. torch.addcdiv,
  42. torch.Tensor.addcmul,
  43. torch.addcmul,
  44. torch.Tensor.addr,
  45. torch.addr,
  46. torch.Tensor.angle,
  47. torch.angle,
  48. torch.Tensor.asin,
  49. torch.asin,
  50. torch.Tensor.asinh,
  51. torch.asinh,
  52. torch.Tensor.atan,
  53. torch.atan,
  54. torch.Tensor.atan2,
  55. torch.atan2,
  56. torch.Tensor.atanh,
  57. torch.atanh,
  58. torch.Tensor.bitwise_and,
  59. torch.bitwise_and,
  60. torch.Tensor.bitwise_left_shift,
  61. torch.bitwise_left_shift,
  62. torch.Tensor.bitwise_not,
  63. torch.bitwise_not,
  64. torch.Tensor.bitwise_or,
  65. torch.bitwise_or,
  66. torch.Tensor.bitwise_right_shift,
  67. torch.bitwise_right_shift,
  68. torch.Tensor.bitwise_xor,
  69. torch.bitwise_xor,
  70. torch.Tensor.ceil,
  71. torch.ceil,
  72. torch.celu,
  73. torch.nn.functional.celu,
  74. torch.Tensor.clamp,
  75. torch.clamp,
  76. torch.Tensor.clamp_max,
  77. torch.clamp_max,
  78. torch.Tensor.clamp_min,
  79. torch.clamp_min,
  80. torch.Tensor.copysign,
  81. torch.copysign,
  82. torch.Tensor.cos,
  83. torch.cos,
  84. torch.Tensor.cosh,
  85. torch.cosh,
  86. torch.Tensor.deg2rad,
  87. torch.deg2rad,
  88. torch.Tensor.digamma,
  89. torch.digamma,
  90. torch.Tensor.div,
  91. torch.div,
  92. torch.dropout,
  93. torch.nn.functional.dropout,
  94. torch.nn.functional.elu,
  95. torch.Tensor.eq,
  96. torch.eq,
  97. torch.Tensor.erf,
  98. torch.erf,
  99. torch.Tensor.erfc,
  100. torch.erfc,
  101. torch.Tensor.erfinv,
  102. torch.erfinv,
  103. torch.Tensor.exp,
  104. torch.exp,
  105. torch.Tensor.exp2,
  106. torch.exp2,
  107. torch.Tensor.expm1,
  108. torch.expm1,
  109. torch.feature_dropout,
  110. torch.Tensor.float_power,
  111. torch.float_power,
  112. torch.Tensor.floor,
  113. torch.floor,
  114. torch.Tensor.floor_divide,
  115. torch.floor_divide,
  116. torch.Tensor.fmod,
  117. torch.fmod,
  118. torch.Tensor.frac,
  119. torch.frac,
  120. torch.Tensor.frexp,
  121. torch.frexp,
  122. torch.Tensor.gcd,
  123. torch.gcd,
  124. torch.Tensor.ge,
  125. torch.ge,
  126. torch.nn.functional.gelu,
  127. torch.nn.functional.glu,
  128. torch.Tensor.gt,
  129. torch.gt,
  130. torch.Tensor.hardshrink,
  131. torch.hardshrink,
  132. torch.nn.functional.hardshrink,
  133. torch.nn.functional.hardsigmoid,
  134. torch.nn.functional.hardswish,
  135. torch.nn.functional.hardtanh,
  136. torch.Tensor.heaviside,
  137. torch.heaviside,
  138. torch.Tensor.hypot,
  139. torch.hypot,
  140. torch.Tensor.i0,
  141. torch.i0,
  142. torch.Tensor.igamma,
  143. torch.igamma,
  144. torch.Tensor.igammac,
  145. torch.igammac,
  146. torch.Tensor.isclose,
  147. torch.isclose,
  148. torch.Tensor.isfinite,
  149. torch.isfinite,
  150. torch.Tensor.isinf,
  151. torch.isinf,
  152. torch.Tensor.isnan,
  153. torch.isnan,
  154. torch.Tensor.isneginf,
  155. torch.isneginf,
  156. torch.Tensor.isposinf,
  157. torch.isposinf,
  158. torch.Tensor.isreal,
  159. torch.isreal,
  160. torch.Tensor.kron,
  161. torch.kron,
  162. torch.Tensor.lcm,
  163. torch.lcm,
  164. torch.Tensor.ldexp,
  165. torch.ldexp,
  166. torch.Tensor.le,
  167. torch.le,
  168. torch.nn.functional.leaky_relu,
  169. torch.Tensor.lerp,
  170. torch.lerp,
  171. torch.Tensor.lgamma,
  172. torch.lgamma,
  173. torch.Tensor.log,
  174. torch.log,
  175. torch.Tensor.log10,
  176. torch.log10,
  177. torch.Tensor.log1p,
  178. torch.log1p,
  179. torch.Tensor.log2,
  180. torch.log2,
  181. torch.nn.functional.logsigmoid,
  182. torch.Tensor.logical_and,
  183. torch.logical_and,
  184. torch.Tensor.logical_not,
  185. torch.logical_not,
  186. torch.Tensor.logical_or,
  187. torch.logical_or,
  188. torch.Tensor.logical_xor,
  189. torch.logical_xor,
  190. torch.Tensor.logit,
  191. torch.logit,
  192. torch.Tensor.lt,
  193. torch.lt,
  194. torch.Tensor.maximum,
  195. torch.maximum,
  196. torch.Tensor.minimum,
  197. torch.minimum,
  198. torch.nn.functional.mish,
  199. torch.Tensor.mvlgamma,
  200. torch.mvlgamma,
  201. torch.Tensor.nan_to_num,
  202. torch.nan_to_num,
  203. torch.Tensor.ne,
  204. torch.ne,
  205. torch.Tensor.neg,
  206. torch.neg,
  207. torch.Tensor.nextafter,
  208. torch.nextafter,
  209. torch.Tensor.outer,
  210. torch.outer,
  211. torch.polar,
  212. torch.Tensor.polygamma,
  213. torch.polygamma,
  214. torch.Tensor.positive,
  215. torch.positive,
  216. torch.Tensor.pow,
  217. torch.pow,
  218. torch.Tensor.prelu,
  219. torch.prelu,
  220. torch.nn.functional.prelu,
  221. torch.Tensor.rad2deg,
  222. torch.rad2deg,
  223. torch.Tensor.reciprocal,
  224. torch.reciprocal,
  225. torch.Tensor.relu,
  226. torch.relu,
  227. torch.nn.functional.relu,
  228. torch.nn.functional.relu6,
  229. torch.Tensor.remainder,
  230. torch.remainder,
  231. torch.Tensor.round,
  232. torch.round,
  233. torch.rrelu,
  234. torch.nn.functional.rrelu,
  235. torch.Tensor.rsqrt,
  236. torch.rsqrt,
  237. torch.rsub,
  238. torch.selu,
  239. torch.nn.functional.selu,
  240. torch.Tensor.sgn,
  241. torch.sgn,
  242. torch.Tensor.sigmoid,
  243. torch.sigmoid,
  244. torch.nn.functional.sigmoid,
  245. torch.Tensor.sign,
  246. torch.sign,
  247. torch.Tensor.signbit,
  248. torch.signbit,
  249. torch.nn.functional.silu,
  250. torch.Tensor.sin,
  251. torch.sin,
  252. torch.Tensor.sinc,
  253. torch.sinc,
  254. torch.Tensor.sinh,
  255. torch.sinh,
  256. torch.nn.functional.softplus,
  257. torch.nn.functional.softshrink,
  258. torch.Tensor.sqrt,
  259. torch.sqrt,
  260. torch.Tensor.square,
  261. torch.square,
  262. torch.Tensor.sub,
  263. torch.sub,
  264. torch.Tensor.tan,
  265. torch.tan,
  266. torch.Tensor.tanh,
  267. torch.tanh,
  268. torch.nn.functional.tanh,
  269. torch.threshold,
  270. torch.nn.functional.threshold,
  271. torch.trapz,
  272. torch.Tensor.true_divide,
  273. torch.true_divide,
  274. torch.Tensor.trunc,
  275. torch.trunc,
  276. torch.Tensor.xlogy,
  277. torch.xlogy,
  278. torch.rand_like,
  279. )