test_operators.py 649 B

123456789101112131415161718192021222324
  1. import torch.library
  2. from torch import Tensor
  3. from torch.autograd import Function
  4. _test_lib_def = torch.library.Library("_inductor_test", "DEF")
  5. _test_lib_def.define("realize(Tensor self) -> Tensor")
  6. _test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
  7. for dispatch_key in ("CPU", "CUDA", "Meta"):
  8. _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
  9. class Realize(Function):
  10. @staticmethod
  11. def forward(ctx, x):
  12. return torch.ops._inductor_test.realize(x)
  13. @staticmethod
  14. def backward(ctx, grad_output):
  15. return grad_output
  16. def realize(x: Tensor) -> Tensor:
  17. return Realize.apply(x)