123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- from .modules import * # noqa: F403
- from .parameter import (
- Parameter as Parameter,
- UninitializedParameter as UninitializedParameter,
- UninitializedBuffer as UninitializedBuffer,
- )
- from .parallel import DataParallel as DataParallel
- from . import init
- from . import functional
- from . import utils
def factory_kwargs(kwargs):
    r"""Return a canonicalized dict of factory kwargs built from ``kwargs``.

    The result can be passed directly to factory functions such as
    ``torch.empty``. Unrecognized top-level keys are rejected up front.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing ``kwargs`` along
    directly?

    1. This function does error validation, so if there are unexpected kwargs
       we will immediately report an error, instead of deferring it to the
       factory call.
    2. This function supports a special ``factory_kwargs`` argument, which can
       be used to explicitly specify a kwarg to be used for factory functions,
       in the event one of the factory kwargs conflicts with an already
       existing argument in the signature (e.g. in the signature
       ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``).

    Args:
        kwargs: Mapping of keyword arguments, or ``None``. The only accepted
            keys are ``"device"``, ``"dtype"``, ``"memory_format"`` and
            ``"factory_kwargs"``. The value under ``"factory_kwargs"`` is a
            mapping (or ``None``, treated as empty) whose entries are copied
            into the result as-is; its keys are not validated here.

    Returns:
        A new dict holding the entries of the nested ``factory_kwargs``
        mapping followed by whichever of ``device``, ``dtype`` and
        ``memory_format`` were given at the top level. Neither ``kwargs`` nor
        the nested mapping is modified. ``{}`` is returned when ``kwargs`` is
        ``None``.

    Raises:
        TypeError: If ``kwargs`` contains a key outside the accepted set, or
            if the same key is given both at the top level and inside
            ``factory_kwargs``.
    """
    if kwargs is None:
        return {}

    # A tuple (not a set) so that duplicate detection and the key order of the
    # result do not depend on string hash randomization.
    simple_keys = ("device", "dtype", "memory_format")
    expected_keys = {*simple_keys, "factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # Copy the nested mapping so the caller's input is never mutated. An
    # explicit ``factory_kwargs=None`` is treated like an absent one, mirroring
    # how ``kwargs=None`` is handled above.
    nested = kwargs.get("factory_kwargs")
    r = {} if nested is None else dict(nested)

    for k in simple_keys:
        if k in kwargs:
            if k in r:
                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
            r[k] = kwargs[k]
    return r
|