123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- import torch
- # pointwise operators can go through a faster pathway
# Base names (without the surrounding double underscores) of Tensor magic
# methods. Every entry must be a non-empty method name; an empty string
# would expand to the meaningless dunder '____'.
# NOTE(review): not referenced by the visible code in this module; confirm
# it is still consumed elsewhere before extending it.
tensor_magic_methods = [
    'add',
]
# Binary operators that also have a reflected form. Each base name `op`
# here contributes both `op` and `'r' + op` (e.g. 'add' -> 'add', 'radd')
# to `pointwise_magic_methods` below.
pointwise_magic_methods_with_reverse = (
    'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
    'pow', 'lshift', 'rshift', 'and', 'or', 'xor',
)

# Base names (no surrounding double underscores) of every Tensor magic
# method treated as pointwise. Each name is listed exactly once.
pointwise_magic_methods = (
    # Binary operators followed immediately by their reflected variant.
    *(name
      for op in pointwise_magic_methods_with_reverse
      for name in (op, 'r' + op)),
    # Comparisons.
    'eq', 'gt', 'le', 'lt', 'ge', 'ne',
    # Unary operators.
    'neg', 'pos', 'abs', 'invert',
    # In-place operators.
    'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
    'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
    'ior', 'ixor',
    # Conversion protocols (__int__, __long__, __float__, __complex__).
    'int', 'long', 'float', 'complex',
)

# Full dunder attribute names, e.g. 'add' -> '__add__', in the same order
# as `pointwise_magic_methods`.
pointwise_methods = tuple(f'__{m}__' for m in pointwise_magic_methods)
# Callables treated as pointwise operators: every Tensor magic method named
# in `pointwise_methods`, followed by explicitly listed torch, torch.Tensor
# and torch.nn.functional callables. Each callable is listed once.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    # NOTE(review): addr adds an outer product; confirm it is safe to route
    # through the pointwise pathway.
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    torch.dropout,
    torch.nn.functional.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    # NOTE(review): kron is a Kronecker (block) product, not elementwise;
    # confirm it belongs in this table.
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    # NOTE(review): outer forms an outer product of its inputs, not an
    # elementwise result; confirm it belongs in this table.
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    # NOTE(review): per the PyTorch docs, trapz integrates (reduces) along a
    # dimension, so it is not elementwise; confirm it is safe here.
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    # NOTE(review): rand_like only uses its input's shape/dtype, not its
    # values; presumably listed because it is shape-preserving. Confirm.
    torch.rand_like,
)
|