annotate.py 929 B

123456789101112131415161718192021
  1. from torch.fx.proxy import Proxy
  2. from ._compatibility import compatibility
  3. @compatibility(is_backward_compatible=False)
  4. def annotate(val, type):
  5. # val could be either a regular value (not tracing)
  6. # or fx.Proxy (tracing)
  7. if isinstance(val, Proxy):
  8. if val.node.type:
  9. raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
  10. f" Existing type is {val.node.type} "
  11. f"and new type is {type}. "
  12. f"This could happen if you tried to annotate a function parameter "
  13. f"value (in which case you should use the type slot "
  14. f"on the function signature) or you called "
  15. f"annotate on the same value twice")
  16. else:
  17. val.node.type = type
  18. return val
  19. else:
  20. return val