from typing import Any, Dict, List, Optional, Tuple, Union

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)


_valueT: Optional[BaseCppType] = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )
    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val
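

# Usage sketch: a backend registers its IR value type once, before any codegen
# runs (for LTC this is expected to happen in run_gen_lazy_tensor()):
#
#   setValueT(BaseCppType("torch::lazy", "Value"))
#   str(getValueT())  # -> "torch::lazy::Value"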


# This is a bad hack. I need to refactor the data model to represent each arg
# in the schema as an object, making it easier to represent special properties
# of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: "LazyIrProperties", *, symint: bool
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values (which wrap lazy::Nodes, the
          representation Lazy IR uses for tensors)
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead
          of IntArrayRef)

    There is special handling for Optional[Tensor], List[Tensor], etc.; hence
    "tensor-like".

    This handling is incomplete: assertions mark the places where more types
    are expected to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::Scalar has special handling,
            # and is wrapped in a lazy::Value just like at::Tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            if symint:
                return BaseCType(getValueT())
            else:
                return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Values directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))
        else:
            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: Optional["LazyIrProperties"] = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar in
        # lazy Value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
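

# Sketch: isValueType operates on the *converted* CTypes produced by
# process_ir_type, e.g. (assuming a value type has been registered):
#   isValueType(BaseCType(getValueT()))                 # True
#   isValueType(OptionalCType(BaseCType(getValueT())))  # True
#   isValueType(VectorCType(BaseCType(longT)))          # False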


def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar in
        # lazy Value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: Optional[CType]
    is_optional: bool
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is a lie, it is False for symint lists
    is_symint_or_list: bool

    # Whether or not we are treating this as symint or not
    symint: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generator arguments to be optional, since currently they all are"
            # There is no handling for generators in TorchScript IR (or XLA),
            # so we fall back to eager if the (optional) generator has a value;
            # otherwise it's null and safe to exclude from lazy IR.
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )
        self.is_lazy_value = not self.is_generator and isValueType(
            self.lazy_type, properties
        )

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
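

# Construction sketch (hypothetical values; requires setValueT() to have been
# called, since Tensor arguments are converted through getValueT()):
#
#   arg = Argument.parse("Tensor self")
#   lazy_arg = LazyArgument(arg, LazyIrProperties("ShapePrecompute", "Lower"), symint=False)
#   lazy_arg.lazy_type      # BaseCType(getValueT())
#   lazy_arg.is_lazy_value  # True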


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        properties: Dict[Tuple[str, ...], Optional[str]] = {
            p: None for p in LazyIrProperties.Properties
        }
        self.__dict__["properties"] = properties

        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")
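

# Usage sketch: properties within a group are mutually exclusive, so setting
# one clears whichever property in the same group was previously set:
#
#   props = LazyIrProperties("ShapePrecompute", "Lower")
#   props.ShapePrecompute  # True
#   props.ShapeCache = True
#   props.ShapePrecompute  # False (cleared when ShapeCache was set)
#   props.Lower            # True (a different group; unaffected)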


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# while preserving the original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    # original function schema
    func: FunctionSchema

    # Whether or not we are code-genning for SymInt or not
    symint: bool

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: Optional[str] = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: Optional[LazyIrProperties] = None,
        *,
        symint: bool,
    ):
        if properties:
            self.properties = properties

        self.func = func
        self.symint = symint
        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = func.arguments.self_arg.argument
                positional_args.append(
                    LazyArgument(arg, self.properties, symint=symint)
                )
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(arg.name, arg.type)
                keyword_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns
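
    # Construction sketch (assumes setValueT() has already registered a value
    # type, since Tensor arguments are converted through getValueT()):
    #
    #   func = FunctionSchema.parse(
    #       "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    #   )
    #   schema = LazyIrSchema(func, symint=False)
    #   schema.node_name  # "AddTensor" (see the property below)
    #   schema.aten_name  # "add"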

    @property
    def node_name(self) -> str:
        """
        Return the camel-case version of the op in the node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() or "" for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # while others care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special-cased: they are needed for fallback/shape-inference but not supported
        # in TS lowerings, and are therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []
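
    # Filtered-view sketch (hypothetical schema): for
    #   softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
    # `self` becomes a lazy value while `dim` and `dtype` remain scalars, so:
    #   schema.filtered_args(values=True, scalars=False)  # [self]
    #   schema.filtered_args(values=False, scalars=True)  # [dim, dtype]
    #   schema.filtered_args()                            # [self, dim, dtype]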

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
|