constant.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import operator
  2. from typing import Dict, List
  3. import torch
  4. from .. import variables
  5. from ..exc import unimplemented
  6. from ..utils import HAS_NUMPY, istype, np
  7. from .base import typestr, VariableTracker
  8. class ConstantVariable(VariableTracker):
  9. def __init__(self, value, **kwargs):
  10. super().__init__(**kwargs)
  11. assert not isinstance(value, torch.Tensor)
  12. assert not isinstance(value, torch.SymInt)
  13. assert not isinstance(value, torch.SymFloat)
  14. if HAS_NUMPY and isinstance(value, np.number):
  15. self.value = value.item()
  16. else:
  17. self.value = value
  18. def as_proxy(self):
  19. return self.value
  20. def __str__(self):
  21. # return f"ConstantVariable({self.value})"
  22. return f"ConstantVariable({type(self.value).__name__})"
  23. def python_type(self):
  24. return type(self.value)
  25. def as_python_constant(self):
  26. return self.value
  27. @property
  28. def items(self):
  29. """
  30. Need this when adding a BaseListVariable and a ConstantVariable together.
  31. Happens in detectron2.
  32. """
  33. return self.unpack_var_sequence(tx=None)
  34. def getitem_const(self, arg: VariableTracker):
  35. return ConstantVariable(
  36. self.value[arg.as_python_constant()],
  37. **VariableTracker.propagate([self, arg]),
  38. )
  39. @staticmethod
  40. def is_literal(obj):
  41. if type(obj) in (int, float, bool, type(None), str):
  42. return True
  43. if type(obj) in (list, tuple, set, frozenset):
  44. return all(ConstantVariable.is_literal(x) for x in obj)
  45. return False
  46. def unpack_var_sequence(self, tx):
  47. try:
  48. options = VariableTracker.propagate([self])
  49. return [ConstantVariable(x, **options) for x in self.as_python_constant()]
  50. except TypeError as e:
  51. raise NotImplementedError from e
  52. def const_getattr(self, tx, name):
  53. member = getattr(self.value, name)
  54. if callable(member):
  55. raise NotImplementedError()
  56. return member
  57. def call_method(
  58. self,
  59. tx,
  60. name,
  61. args: "List[VariableTracker]",
  62. kwargs: "Dict[str, VariableTracker]",
  63. ) -> "VariableTracker":
  64. from .tensor import SymNodeVariable
  65. options = VariableTracker.propagate(self, args, kwargs.values())
  66. if istype(self.value, tuple):
  67. # empty tuple constant etc
  68. return variables.TupleVariable(
  69. items=self.unpack_var_sequence(tx), source=self.source, **options
  70. ).call_method(tx, name, args, kwargs)
  71. if any([isinstance(x, SymNodeVariable) for x in args]):
  72. # Promote to SymNodeVariable for operations involving dynamic shapes.
  73. return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
  74. tx, name, args, kwargs
  75. )
  76. try:
  77. const_args = [a.as_python_constant() for a in args]
  78. const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  79. except NotImplementedError:
  80. return super().call_method(tx, name, args, kwargs)
  81. def has_arith_binop(num_ty):
  82. return (
  83. isinstance(self.value, num_ty)
  84. and hasattr(operator, name)
  85. and len(args) == 1
  86. and args[0].is_python_constant()
  87. )
  88. if isinstance(self.value, str) and name in str.__dict__.keys():
  89. assert not kwargs
  90. method = getattr(self.value, name)
  91. return ConstantVariable(method(*const_args, **const_kwargs), **options)
  92. elif has_arith_binop(int) or has_arith_binop(float):
  93. op = getattr(operator, name)
  94. add_target = const_args[0]
  95. if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
  96. from .tensor import SymNodeVariable
  97. # Addition between a non sym and sym makes a sym
  98. # sym_num = tx.output.register_attr_or_module(
  99. # add_target, f"sym_shape_{add_target}", source=None
  100. # )
  101. proxy = tx.output.create_proxy(
  102. "call_function", op, (self.value, add_target), {}
  103. )
  104. return SymNodeVariable.create(tx, proxy, add_target, **options)
  105. return ConstantVariable(op(self.value, add_target), **options)
  106. elif name == "__len__" and not (args or kwargs):
  107. return ConstantVariable(len(self.value), **options)
  108. elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
  109. assert not kwargs
  110. search = args[0].as_python_constant()
  111. result = search in self.value
  112. return ConstantVariable(result, **options)
  113. unimplemented(f"const method call {typestr(self.value)}.{name}")
  114. class EnumVariable(VariableTracker):
  115. def __init__(self, value, **kwargs):
  116. super().__init__(**kwargs)
  117. self.value = value
  118. def as_proxy(self):
  119. return self.value
  120. def __str__(self):
  121. return f"EnumVariable({type(self.value)})"
  122. def python_type(self):
  123. return type(self.value)
  124. def as_python_constant(self):
  125. return self.value
  126. def const_getattr(self, tx, name):
  127. member = getattr(self.value, name)
  128. if callable(member):
  129. raise NotImplementedError()
  130. return member