import functools

import torch
import torch.library

# Ensure that torch.ops.torchvision is visible
import torchvision.extension  # noqa: F401
@functools.lru_cache(maxsize=None)
def get_meta_lib():
    """Return the cached ``torch.library.Library`` handle used to register
    Meta-device (shape-only) kernels under the ``torchvision`` namespace.

    Cached so the Library object is created exactly once per process.
    """
    return torch.library.Library("torchvision", "IMPL", "Meta")
def register_meta(op_name, overload_name="default"):
    """Decorator factory: register the decorated function as the Meta-device
    implementation of ``torchvision::<op_name>.<overload_name>``.

    Registration only happens when the compiled torchvision ops are actually
    available; either way the decorated function is returned unchanged.
    """

    def wrapper(fn):
        if torchvision.extension._has_ops():
            op = getattr(torch.ops.torchvision, op_name)
            overload = getattr(op, overload_name)
            get_meta_lib().impl(overload, fn)
        return fn

    return wrapper
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Meta kernel for ``torchvision::roi_align``: validate inputs and return
    an uninitialized tensor of the correct output shape, so shape/dtype
    propagation works without executing the real op.
    """
    # rois must be a 2-D tensor of [batch_index, x1, y1, x2, y2] rows.
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    # Output is one pooled (C, pooled_height, pooled_width) map per roi.
    batch, channels, height, width = input.shape
    out_shape = (rois.size(0), channels, pooled_height, pooled_width)
    return input.new_empty(out_shape)
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
    """Meta kernel for ``torchvision::_roi_align_backward``: validate dtype
    agreement and return an uninitialized gradient tensor shaped like the
    forward input, (batch_size, channels, height, width).
    """
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    grad_input_shape = (batch_size, channels, height, width)
    return grad.new_empty(grad_input_shape)