# _meta_registrations.py
# Meta (shape-only) kernel registrations for torchvision operators.
  1. import functools
  2. import torch
  3. import torch.library
  4. # Ensure that torch.ops.torchvision is visible
  5. import torchvision.extension # noqa: F401
  6. @functools.lru_cache(None)
  7. def get_meta_lib():
  8. return torch.library.Library("torchvision", "IMPL", "Meta")
  9. def register_meta(op_name, overload_name="default"):
  10. def wrapper(fn):
  11. if torchvision.extension._has_ops():
  12. get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
  13. return fn
  14. return wrapper
  15. @register_meta("roi_align")
  16. def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  17. torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
  18. torch._check(
  19. input.dtype == rois.dtype,
  20. lambda: (
  21. "Expected tensor for input to have the same type as tensor for rois; "
  22. f"but type {input.dtype} does not equal {rois.dtype}"
  23. ),
  24. )
  25. num_rois = rois.size(0)
  26. _, channels, height, width = input.size()
  27. return input.new_empty((num_rois, channels, pooled_height, pooled_width))
  28. @register_meta("_roi_align_backward")
  29. def meta_roi_align_backward(
  30. grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
  31. ):
  32. torch._check(
  33. grad.dtype == rois.dtype,
  34. lambda: (
  35. "Expected tensor for grad to have the same type as tensor for rois; "
  36. f"but type {grad.dtype} does not equal {rois.dtype}"
  37. ),
  38. )
  39. return grad.new_empty((batch_size, channels, height, width))