1234567891011121314151617181920212223242526272829303132333435 |
- import functools
- from inspect import signature
- from .common_op_utils import _basic_validation
- """
- Common utilities to register ops on ShardedTensor, ReplicatedTensor
- and PartialTensor.
- """
def _register_op(op, func, op_table):
    """
    Performs basic validation and registers the provided op in the given
    op_table.

    Args:
        op: Key under which ``func`` is stored in ``op_table``.
        func: Callable to register. It must declare exactly four parameters,
            conventionally ``(types, args, kwargs, process_group)``.
        op_table: Mapping that receives the ``op -> func`` entry. Plain item
            assignment is used, so an existing entry for ``op`` is replaced.

    Raises:
        TypeError: If ``func``'s signature does not have exactly four
            parameters. Nothing is written to ``op_table`` in that case.
    """
    # Inspect once and reuse the result for both the check and the message.
    func_signature = signature(func)
    if len(func_signature.parameters) != 4:
        raise TypeError(
            'Custom sharded op function expects signature: '
            '(types, args, kwargs, process_group), but received '
            f'signature: {func_signature}')
    op_table[op] = func
def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` with argument validation for ``op`` and register
    the resulting callable in ``op_table``.

    Args:
        wrapped_func: Implementation of ``op``. It is invoked as
            ``wrapped_func(types, args, kwargs, process_group)``.
        op: Key the wrapper is registered under; also passed to
            ``_basic_validation`` together with ``args`` and ``kwargs``.
        op_table: Mapping that receives the ``op -> wrapper`` entry through
            ``_register_op``.

    Returns:
        The registered wrapper. ``functools.wraps`` copies
        ``wrapped_func``'s metadata onto it (name, docstring, ``__wrapped__``).
    """
    def wrapper(types, args, kwargs, process_group):
        # Run the shared argument checks before delegating to the real op.
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    # Same effect as decorating ``wrapper`` with @functools.wraps(...).
    # Setting ``__wrapped__`` also means ``inspect.signature`` (used by
    # ``_register_op``) reports ``wrapped_func``'s parameters, not the
    # wrapper's own.
    registered = functools.wraps(wrapped_func)(wrapper)
    _register_op(op, registered, op_table)
    return registered
|