from typing import Dict, List, NoReturn, Sequence, Union from torchgen.api.types import ( ArrayRefCType, BaseCType, Binding, boolT, ConstRefCType, deviceT, Expr, intArrayRefT, iOptTensorListRefT, layoutT, ListCType, longT, memoryFormatT, MutRefCType, NamedCType, opmath_t, OptionalCType, optionalIntArrayRefT, optionalScalarRefT, optionalSymIntArrayRefT, optionalTensorRefT, scalar_t, scalarT, scalarTypeT, SpecialArgName, symIntArrayRefT, SymIntT, tensorOptionsT, tensorT, VectorCType, ) # This file implements a small program synthesis engine that implements # conversions between one API to another. # # The key data type in this file in NamedCType, short for Named C++ semantic type. A NamedCType # represents a C++ type, plus semantic information about what it represents. # For example, consider the argument "bool pin_memory"; its normal C++ type is # "bool", but its C++ semantic type also keeps track that this represents a # "pin_memory"; you can't just use a random other boolean in a context where you # need a "pin_memory"! # # The translator takes a list of needed NamedCTypes, and then figures out how # to construct expressions with these NamedCTypes from the given bindings. Many # of these expressions are trivial (I need a Tensor other; there's a Tensor # other scope); others are more nontrivial and may require packing/unpacking. # Some examples of non-trivial action: # # - Need the "dtype" binding? Well, maybe "dtype" isn't available # in the context, instead, "options" is, and you need to extract # it from there. (Gather) # # - Need the "context" binding? Well, maybe "context" isn't available # in the context, and you need to construct it from "dtype", "device", # etc. (Scatter) # # - Need the "memory_format" binding? Well, actually, it's available # from both "memory_format" and "options", so you had better make sure # they are consistent. 
#     (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    """Raised when a goal cannot be synthesized from the bindings in scope."""


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically you need less information:
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient.  If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> List[Expr]:
    """Synthesize one C++ expression per goal using only the in-scope bindings.

    Args:
        bindings: In-scope values. A ``Binding`` contributes an expression equal
            to its name, typed by its ``nctype``; an ``Expr`` is used as-is.
        goals: The semantic types to produce, in order. A ``Binding`` goal is
            reduced to its ``nctype``.
        method: If True, the generated code is inside a Tensor method, so
            ``self`` is additionally available (as ``const_cast<Tensor&>(*this)``).
        allow_expensive_conversions: If True, also permit reference -> value
            conversions (e.g. ``IntArrayRef`` -> ``std::vector<int64_t>``) that
            copy data; see Note [translation from C++ reference to value types].

    Returns:
        One ``Expr`` per goal, in the same order as ``goals``.

    Raises:
        UnsatError: If some goal cannot be synthesized from the bindings.
    """
    binding_exprs: List[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]

    goal_ctypes: List[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g for g in goals
    ]

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher (PEP 634)
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type

        # NOTE: a rule deriving `const at::Tensor&` from a
        # `const optional<Tensor>&` binding used to live here, but it tested
        # isinstance(b.type, ConstRefCType); b.type is a NamedCType, so it never
        # fired.  It was removed rather than enabled: enabling it would let the
        # const_cast rule in solve() emit const_cast<Tensor&> of a temporary.
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"({b.expr}).to<opmath_t>()"

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # Plain "*this" would be better for the const& binding, but the
        # const_cast form is kept for byte-for-byte compatibility.
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"

    def unsat(goal: NamedCType) -> NoReturn:
        """Raise UnsatError for ``goal``, listing every binding in the context."""
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A crude backtracking search implementation.  It does backtracking via
    # the call stack (a bad idea!) and for the most part tries to avoid
    # backtracking.  In particular, if direct=True, we won't try to do any
    # fancy synthesis, just trivial conversions (e.g., "T a" is OK for
    # "const T& a").  So all of the existing rules in this function simply
    # try to solve immediately, and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # TODO: These are referentially equal, shouldn't have to do this;
        # ensuring we don't use type synonym IntArrayRef in codegen would
        # help
        if goal.type == ArrayRefCType(BaseCType(longT)):
            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # we don't actually care about the value of pin_memory.
                # The out= tensor is solved only to check that one is in
                # scope (raising UnsatError otherwise); its expression is unused.
                direct_solve(out_tensor_ctype)
                return "c10::nullopt"

        # We can always do translations from value types to reference types,
        # like vector<int64_t> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            try:
                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
                return f"c10::fromIntArrayRefSlow({r})"
            except UnsatError:
                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(SymIntT):
            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
        elif goal.type == OptionalCType(BaseCType(SymIntT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
            )
            return f"{argname}.has_value() ? c10::make_optional(c10::SymInt(*{argname})) : c10::nullopt"
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.expect_int()"
        elif goal.type == OptionalCType(BaseCType(longT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
            )
            return f"{argname}.has_value() ? c10::make_optional({argname}->expect_int()) : c10::nullopt"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
            except UnsatError:
                argname = direct_solve(
                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
                )
                return f"{argname}.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : c10::nullopt"
        elif goal.type == BaseCType(optionalSymIntArrayRefT):
            # TODO: You might also want to solve this from longSymVec_ctype or
            # an optional version of it
            argname = direct_solve(
                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
            )
            return f"{argname}.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*{argname})) : c10::nullopt"
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is
        # expensive (O(n) to convert from IntArrayRef to vector<int64_t>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT)
                )
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT)
                )
                argname = direct_solve(optionalScalarRef_ctype)
                # NOTE(review): argname is an OptionalScalarRef here; wrapping it
                # directly may yield optional<OptionalScalarRef> rather than
                # optional<Scalar> -- confirm against generated code.
                return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                # Previously this branch repeated the optional<Scalar> condition
                # above and was unreachable.
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT)
                )
                argname = direct_solve(optionalTensorRef_ctype)
                # Dereference the OptionalTensorRef so the result is optional<Tensor>.
                return f"{argname}.has_value() ? c10::make_optional(*{argname}) : c10::nullopt"

        # Technically, we also need to handle cases of C++ containers holding reference types.
        # But there currently aren't any ops that require lambda capture codegen
        # with arguments like std::vector<IntArrayRef>.
        # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
        # We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT))
            )
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]