from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    Expr,
    NamedCType,
    opmath_t,
    scalar_t,
    StructuredImplSignature,
    VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    NativeFunctionsGroup,
    ScalarType,
    UfuncKey,
)
from torchgen.utils import OrderedSet

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary

# TODO: use BackendIndex
# dispatch_key: DispatchKey  # only CPU/CUDA right now


# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because USERS instantiate them at a
# particular scalar type.  A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#


@dataclass(frozen=True)
class UfunctorSignature:
    g: NativeFunctionsGroup
    scalar_tensor_idx: Optional[int]
    name: str

    def arguments(self) -> UfunctorBindings:
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> List[Binding]:
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().ctor)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"


@dataclass(frozen=True)
class UfuncSignature:
    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> List[Binding]:
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to a template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#
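# For intuition, tracing this through for add with the example functor shown
# above (the names below are illustrative, not authoritative):
#
#   StructuredImplSignature context: self, other, alpha from the structured impl
#     ~> functor constructor sig:    CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#   Functor constructor context
#     ~> functor fields sig:         opmath_t other_; opmath_t alpha_;
#   Functor apply context (fields + apply sig, i.e. scalar_t self)
#     ~> template sig call:          ufunc::add(static_cast<opmath_t>(self), other_, alpha_)
#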

def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    num_tensors = sum(
        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    )
    return num_tensors == 2


def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)


@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    scalar_idx: int
    ctor_tensor: str
    ufunc_key: UfuncKey


BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]
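
# NB: scalar_idx here is the index of the input that may be a CPU scalar
# (0 = self, 1 = other); compute_ufunc_cuda_dtype_body adds 1 to obtain the
# TensorIterator operand index, since operand 0 is the output.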


def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
    return body
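
# For intuition, the per-dtype body built above comes out roughly like this for
# add (an illustrative sketch, not verbatim codegen output):
#
#   using opmath_t = at::opmath_type<scalar_t>;
#   if (false) {}
#   else if (iter.is_cpu_scalar(1)) {
#     CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
#     iter.remove_operand(1);
#     gpu_kernel(iter, ufunctor);
#   }
#   else if (iter.is_cpu_scalar(2)) {
#     CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
#     iter.remove_operand(2);
#     gpu_kernel(iter, ufunctor);
#   }
#   else {
#     gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
#   }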


@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                   CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StubSignature:
    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_kernel"

    @property
    def type_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_fn"

    def arguments(self) -> List[Binding]:
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        cpp_args = self.arguments()
        return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from a context where `this` is a TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"


@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""


def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
            BaseTy.Scalar
        ):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
                BaseTy.Scalar
            ):
                continue
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        r: List[Union[Expr, Binding]] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""


@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized somehow
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError()
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""