# op_registry_utils.py
import functools
from inspect import signature

from .common_op_utils import _basic_validation
"""
Common utilities to register ops on ShardedTensor, ReplicatedTensor
and PartialTensor.
"""
  8. def _register_op(op, func, op_table):
  9. """
  10. Performs basic validation and registers the provided op in the given
  11. op_table.
  12. """
  13. if len(signature(func).parameters) != 4:
  14. raise TypeError(
  15. f'Custom sharded op function expects signature: '
  16. f'(types, args, kwargs, process_group), but received '
  17. f'signature: {signature(func)}')
  18. op_table[op] = func
  19. def _decorator_func(wrapped_func, op, op_table):
  20. """
  21. Decorator function to register the given ``op`` in the provided
  22. ``op_table``
  23. """
  24. @functools.wraps(wrapped_func)
  25. def wrapper(types, args, kwargs, process_group):
  26. _basic_validation(op, args, kwargs)
  27. return wrapped_func(types, args, kwargs, process_group)
  28. _register_op(op, wrapper, op_table)
  29. return wrapper