import functools from inspect import signature from .common_op_utils import _basic_validation """ Common utilities to register ops on ShardedTensor, ReplicatedTensor and PartialTensor. """ def _register_op(op, func, op_table): """ Performs basic validation and registers the provided op in the given op_table. """ if len(signature(func).parameters) != 4: raise TypeError( f'Custom sharded op function expects signature: ' f'(types, args, kwargs, process_group), but received ' f'signature: {signature(func)}') op_table[op] = func def _decorator_func(wrapped_func, op, op_table): """ Decorator function to register the given ``op`` in the provided ``op_table`` """ @functools.wraps(wrapped_func) def wrapper(types, args, kwargs, process_group): _basic_validation(op, args, kwargs) return wrapped_func(types, args, kwargs, process_group) _register_op(op, wrapper, op_table) return wrapper