- import copy
- import re
- from dataclasses import dataclass
- from typing import Dict, List, Match, Optional, Sequence, Set, Tuple
- from torchgen.api import cpp
- from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
- from torchgen.model import (
- FunctionSchema,
- NativeFunction,
- NativeFunctionsViewGroup,
- SchemaKind,
- Type,
- )
- from torchgen.utils import IDENT_REGEX
- # Represents a saved attribute involved in backward calculation.
- # Note that it can be a derived property of an input argument, e.g.:
- # we could save `other.scalar_type()` instead of the entire `other` tensor.
- @dataclass(frozen=True)
- class SavedAttribute:
- # The NamedCType holds the updated name and cpp type of the attribute
- # for the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`
- nctype: NamedCType
- # The expression to read the derived property at save time, e.g.:
- # `other.scalar_type()`.
- expr: str
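- # Illustrative sketch (values are hypothetical, assuming a `scalarTypeT` C++ type
- # wrapper along the lines of those in torchgen.api.types): saving `other.scalar_type()`
- # would be modeled roughly as
- #   SavedAttribute(
- #       nctype=NamedCType(name="other_scalar_type", type=BaseCType(scalarTypeT)),
- #       expr="other.scalar_type()",
- #   )
- # so the generated code saves only the scalar type rather than the whole `other` tensor.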
- # Represents a backward formula that calculates derivatives for one
- # or more tensors.
- @dataclass(frozen=True)
- class Derivative:
- # The formula string (legit C++ expression).
- # Note that expressions against input arguments have been replaced with the
- # corresponding saved attributes.
- # E.g.:
- # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
- # here: `mul_tensor_backward(grad, self, other_scalar_type)`
- formula: str
- # The formula string before input argument replacement
- original_formula: str
- # Names of the arguments for which this formula calculates derivatives.
- var_names: Tuple[str, ...]
- # Saved inputs that are referenced by the formula.
- saved_inputs: Tuple[SavedAttribute, ...]
- # Saved outputs that are referenced by the formula.
- saved_outputs: Tuple[SavedAttribute, ...]
- # Gradients that are referenced by name in the formula.
- named_gradients: Set[str]
- # Represents a forward formula that calculates forward derivatives
- # for one tensor.
- @dataclass(frozen=True)
- class ForwardDerivative:
- # The formula string (legit C++ expression).
- # Note that special keywords such as "linear" or "element_wise" have been
- # replaced by the automatically generated formula.
- formula: str
- # Names of the output arguments for which this formula calculates forward
- # derivatives
- var_names: Tuple[str, ...]
- # Types of the output arguments for which this formula calculates forward
- # derivatives
- var_types: Tuple[Type, ...]
- # Inputs for which the forward derivatives are required for this formula
- required_inputs_fw_grad: Optional[Tuple[str, ...]]
- # Inputs for which the primal is required for this formula
- required_inputs_primal: Optional[Tuple[str, ...]]
- # Flag to specify if this formula requires the original value of self
- # This is only used by inplace operations
- required_original_self_value: bool
- # Whether this formula re-uses the out-of-place formula for the in-place variant
- # (as opposed to being specified directly in derivatives.yaml)
- is_reusing_outplace_formula: bool
- # Represents differentiability info for a NativeFunction.
- @dataclass(frozen=True)
- class DifferentiabilityInfo:
- # The base name read from derivatives.yaml.
- name: str
- # The matching native function.
- #
- # There can be multiple NativeFunction having the same base name:
- # - different overloads with different types of input arguments;
- # - in-place/out/functional variants of the same function;
- #
- # We first use the schema string (under the 'name' key) in derivatives.yaml
- # to find the NativeFunction having the same schema string.
- # Then we find the in-place/out/functional variants of the matching function.
- # Among these variants, we choose the one having the same name as the
- # derivatives.yaml entry. If there is no exact match, then we choose the
- # in-place variant.
- # TODO: maybe the logic to search for all variants is no longer necessary?
- func: NativeFunction
- # The name of the generated autograd function.
- # It's set only if we will calculate a derivative, i.e.
- # 'args_with_derivatives' is not empty.
- op: Optional[str]
- # The derivatives formulae for this function.
- # Note that the length of this sequence is the number of differentiable inputs
- derivatives: Sequence[Derivative]
- # The forward derivatives formulae for this function.
- # Note that the length of this sequence is the number of differentiable outputs
- forward_derivatives: Sequence[ForwardDerivative]
- # The union of 'saved_inputs' of all 'derivatives'.
- all_saved_inputs: Sequence[SavedAttribute]
- # The union of 'saved_outputs' of all 'derivatives'.
- all_saved_outputs: Sequence[SavedAttribute]
- # All named gradients that are available for use, in the same
- # order as in the grads vector.
- available_named_gradients: Sequence[str]
- # The named gradients that are used in any of the derivatives.
- # Invariant: all(name in available_named_gradients for name in used_named_gradients)
- used_named_gradients: Set[str]
- # The function's input arguments for which it calculates derivatives.
- # It's the union of 'var_names' of all 'derivatives', sorted by the
- # argument order in the function schema.
- args_with_derivatives: Sequence[Binding]
- # Names of arguments whose derivative formula is 'non_differentiable'.
- non_differentiable_arg_names: Sequence[str]
- # Raw data read from derivatives.yaml.
- output_differentiability: Optional[List[bool]]
- # output_differentiability in derivatives.yaml can be a list of
- # conditions that express if the output is differentiable. In this case,
- # the number of conditions must match the number of outputs
- # (NB: we only support one condition right now).
- # output_differentiability gets populated with True for each condition,
- # while output_differentiability_conditions gets populated with the conditions
- output_differentiability_conditions: Optional[List[str]]
- @property
- def has_derivatives(self) -> bool:
- return len(self.args_with_derivatives) > 0
- # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
- # but with a new operator name.
- # This is used when generating "copy" variants of view ops,
- # which are able to use the exact same derivative formula as the original view op
- # See Note [Codegen'd {view}_copy Operators]
- def create_view_copy_from_view_derivative(
- self, g: NativeFunctionsViewGroup
- ) -> Optional["DifferentiabilityInfo"]:
- if g.view_copy is None:
- return None
- f = g.view_copy
- name_split_by_period = self.name.split(".", maxsplit=2)
- # Append a "_copy" to the base name of the operator (but keep the overload name the same)
- view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
- name_split_by_period[1:]
- )
- view_copy_op_name = None if self.op is None else f"{self.op}_copy"
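- # Rough examples of the renaming above (names are illustrative): a derivatives.yaml
- # entry named "view.dtype" becomes "view_copy.dtype" (only the base name gets the
- # "_copy" suffix), and an autograd op name like "ViewBackward0" would become
- # "ViewBackward0_copy".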
- return DifferentiabilityInfo(
- # Use the "_copy" version of name/func/op
- name=view_copy_name,
- func=f,
- op=view_copy_op_name,
- # But keep all derivative info the same
- derivatives=self.derivatives,
- forward_derivatives=self.forward_derivatives,
- all_saved_inputs=self.all_saved_inputs,
- all_saved_outputs=self.all_saved_outputs,
- available_named_gradients=self.available_named_gradients,
- used_named_gradients=self.used_named_gradients,
- args_with_derivatives=self.args_with_derivatives,
- non_differentiable_arg_names=self.non_differentiable_arg_names,
- output_differentiability=self.output_differentiability,
- output_differentiability_conditions=self.output_differentiability_conditions,
- )
- def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
- if info is None:
- return False
- for derivative in info.derivatives:
- formula = derivative.formula
- if re.search(IDENT_REGEX.format(ident), formula):
- return True
- return False
- def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
- return uses_ident(info, "retain_variables")
- def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
- return uses_ident(info, "grad")
- # Represents a differentiable `Argument`.
- # How is it different from the `Argument` type?
- # - It represents processed Arguments that are differentiable and are only used
- # in the context of the autograd codegen;
- # - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
- @dataclass(frozen=True)
- class DifferentiableInput:
- name: str
- type: Type
- # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
- cpp_type: str
- # Represents a differentiable `Return`.
- # How is it different from the `Return` type?
- # - The name in `Return` is optional. Here it is always populated using the same
- # `cpp.return_names()` method.
- # TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
- # - It represents processed Returns that are differentiable, in compliance with the
- # `output_differentiability` field defined in derivatives.yaml (if specified),
- # and are only used in the context of the autograd codegen;
- @dataclass(frozen=True)
- class DifferentiableOutput:
- name: str
- type: Type
- # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
- cpp_type: str
- @dataclass(frozen=True)
- class NativeFunctionWithDifferentiabilityInfo:
- func: NativeFunction
- info: Optional[Dict[str, DifferentiabilityInfo]]
- fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
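- # Note: both dicts above are keyed by a dispatch-key string; "Default" is the key
- # used when a formula is not specialized per dispatch key (see the key="Default"
- # default parameter of gen_differentiable_outputs below).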
- # TODO: Update comment below since it is out of date.
- def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
- """How are we going to call the underlying implementation of a
- declaration? There are two strategies:
- - use_derived: we want to call the implementation on CPUDoubleType
- (or a similar, derived Type instance). Because these derived
- instances deal in Tensors, not Variables (it's a completely different
- object, so it doesn't dispatch back to VariableType), code on
- this dispatch path needs to wrap/unwrap tensors. If the
- derived implementation takes and returns tensors, the
- implementation is usually differentiable (although we also use
- the derived dispatch path for non-differentiable functions
- that we still want to dispatch on the derived Type instance;
- e.g., size())
- - use_type: we want to call the implementation on Type, because
- it is implemented concretely, and the functions it invokes will
- get dispatched back to VariableType (which will ensure that they
- are differentiable.)
- """
- # fn is derived as long as any of its per-key differentiability infos
- # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
- # and ADInplaceOrViewType. We want to generate these functions as long as a
- # derivative is defined for ANY dispatch key.
- if fn.func.is_abstract or (
- fn.info is not None and any(info.has_derivatives for info in fn.info.values())
- ):
- # If the function is abstract (not implemented on at::Type), we must
- # call the implementation on the derived type with unpacked tensors.
- # If the function has a derivative specified and is concrete, we could
- # call either implementation. We prefer calling the derived
- # type's implementation with unpacked tensors because it is more
- # performant in some cases: any internal calls to other ATen functions
- # won't have the history tracked.
- # If the function has a type dispatched argument (i.e. is a factory),
- # we prefer calling the derived type's implementation both because it is
- # more performant and to ensure factory functions return tensors with _version
- # of 0 (probably not strictly necessary, but nice to have to keep versions simple
- # to understand).
- return "use_derived"
- else:
- # If the function is concrete (we don't have to override it) and we
- # didn't declare it in derivatives.yaml, we'll assume that it is
- # actually implemented out of differentiable functions. (This
- # assumption might not hold, but then you'll see gradcheck fail.)
- return "use_type"
- def match_differentiability_info(
- native_functions: List[NativeFunction],
- differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
- ) -> List[NativeFunctionWithDifferentiabilityInfo]:
- """Sets the "derivative" key on declarations to matching autograd function
- In-place functions will use the out-of-place derivative definition if there
- is no in-place specific derivative.
- """
- functional_info_by_signature = {
- schema.signature(strip_default=True): info_dict
- for schema, info_dict in differentiability_infos.items()
- if schema.kind() == SchemaKind.functional
- }
- non_functional_info_by_signature = {
- schema.signature(strip_default=True): info_dict
- for schema, info_dict in differentiability_infos.items()
- if schema.kind() != SchemaKind.functional
- }
- def find_info(
- f: NativeFunction,
- ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
- # Don't bother matching info to generated out= variants
- if "generated" in f.tags and f.func.kind() == SchemaKind.out:
- return None, False
- # (1) Check for an exact match
- if f.func in differentiability_infos:
- return differentiability_infos[f.func], True
- # (2) If no exact match, check if the out-of-place variant
- # of this operator has a match.
- # i.e. mul() for mul_() or mul_out()
- f_sig = f.func.signature(strip_default=True)
- if f_sig in functional_info_by_signature:
- return functional_info_by_signature[f_sig], False
- # (3) Some operators have a derivative explicitly defined for the mutable
- # variant, but get a code-generated out-of-place variant which does *not*
- # come with a derivative formula.
- # For the generated out-of-place variant, use the mutable variant's formula
- # if it exists.
- if "generated" in f.tags and f_sig in non_functional_info_by_signature:
- info_dict = non_functional_info_by_signature[f_sig]
- # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
- assert not any(
- any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
- for info in info_dict.values()
- ), f"""\
- Attempted to convert a derivative formula for a mutable operator
- to be used automatically by its functional variant ("{str(f.func)}").
- This is not currently supported (we'd need to fix up the formula in the codegen)."""
- return info_dict, False
- # (4) Generate derivative information for unary foreach functions if none is defined in `derivatives.yaml`
- base_op_name = f.func.name.name
- if (
- base_op_name.base.startswith("_foreach")
- and not base_op_name.inplace
- and len(f.func.arguments.post_self_positional) == 0
- ):
- ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
- for function_schema in functional_info_by_signature:
- if (
- function_schema.name.name.base == ref_native_op_name
- and not function_schema.name.name.inplace
- ):
- all_saved_inputs = []
- all_saved_outputs = []
- diff_info_dict = copy.deepcopy(
- differentiability_infos[function_schema]
- )
- diff_info = diff_info_dict["Default"]
- modified_derivative_formulas = []
- for derivative in diff_info.derivatives:
- saved_inputs = []
- saved_outputs = []
- modified_formula = (
- derivative.formula.replace("grad", "grads[i]")
- .replace("self", "self[i]")
- .replace("result", "result[i]")
- )
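- # Hypothetical example: a reference formula such as "grad * result" (e.g. for a
- # unary op like exp) is rewritten to "grads[i] * result[i]" for the corresponding
- # _foreach_ variant; the mention of "result" then triggers saving the tensor-list
- # result below.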
- if "self" in modified_formula:
- saved_inputs.append(
- SavedAttribute(
- nctype=NamedCType(
- name="self", type=BaseCType(tensorListT)
- ),
- expr="self",
- )
- )
- all_saved_inputs.append(saved_inputs[-1])
- if "result" in modified_formula:
- saved_outputs.append(
- SavedAttribute(
- nctype=NamedCType(
- name="result", type=BaseCType(tensorListT)
- ),
- expr="result",
- )
- )
- all_saved_outputs.append(saved_outputs[-1])
- modified_derivative = Derivative(
- formula=modified_formula,
- original_formula=derivative.original_formula,
- var_names=("self",),
- saved_inputs=tuple(saved_inputs),
- saved_outputs=tuple(saved_outputs),
- named_gradients=set(),
- )
- modified_derivative_formulas.append(modified_derivative)
- assert f.func.arguments.self_arg is not None
- diff_info = DifferentiabilityInfo(
- name=base_op_name.base,
- func=f,
- op=f"Foreach{diff_info.op}",
- derivatives=modified_derivative_formulas,
- forward_derivatives=[],
- all_saved_inputs=tuple(set(all_saved_inputs)),
- all_saved_outputs=tuple(set(all_saved_outputs)),
- available_named_gradients=(),
- used_named_gradients=set(),
- args_with_derivatives=[
- Binding(
- name="self",
- nctype=NamedCType(
- name="self", type=BaseCType(tensorListT)
- ),
- argument=f.func.arguments.self_arg.argument,
- default=None,
- )
- ],
- non_differentiable_arg_names=[],
- output_differentiability=None,
- output_differentiability_conditions=None,
- )
- diff_info_dict["Default"] = diff_info
- if f.func not in differentiability_infos:
- differentiability_infos[f.func] = diff_info_dict
- functional_info_by_signature[f.func] = diff_info_dict
- return diff_info_dict, True
- return None, False
- result: List[NativeFunctionWithDifferentiabilityInfo] = []
- for f in native_functions:
- info_dict, is_exact_match = find_info(f)
- # Currently, the '.strides()' to 'strides_or_error' replacement does not support
- # 'self' derivatives of an inplace function, so we must check for this case.
- if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
- for info in info_dict.values():
- for derivative in info.derivatives:
- if "self" in derivative.var_names:
- for saved_input in derivative.saved_inputs:
- assert "strides_or_error" not in saved_input.expr, (
- "Calling '.strides()' in the 'self' derivative formula of an "
- f"in-place function is not supported: {f.func}"
- )
- if not info_dict:
- result.append(
- NativeFunctionWithDifferentiabilityInfo(
- func=f, info=None, fw_derivatives=None
- )
- )
- continue
- fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
- for key, info in info_dict.items():
- if not info.forward_derivatives:
- fw_derivative_dict[key] = []
- continue
- forward_derivatives = info.forward_derivatives
- # For functions that have a single def for out-of-place and inplace (like abs())
- if f.func.kind() == SchemaKind.inplace:
- # For inplace functions there is a little bit of work to do:
- # 1) Validate the formula and make sure the input that is modified is not used:
- # - If there is a formula for the inplace variant of the function (is_exact_match == True) then
- # we make sure that the original value of the input that is being modified inplace (self_p) is
- # not used in the formula. Note that the formula can use "original_self_p" here and that would
- # trigger a clone of the original input.
- # - If we are re-using the out of place formula (is_exact_match == False) then we replace every
- # occurrence of self_p and self_t by original_self_p and original_self_t. These will be
- # populated by a cloned version of the original input (either the clone done by the backward AD
- # logic if self is also used in a backward formula or a special clone that we add).
- # 2) At this point, there cannot be a self_p in the formula.
- # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
- # simply called self (as it is modified inplace).
- # 4) Update the required primals data in case it used to contain "result" but should now contain
- # "self"
- # 5) If it is not an exact match, the user formula is not modifying the existing forward grad
- # inplace as it should. So add some code that makes sure that we do so if the forward grad
- # already exists.
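- # As a rough illustration of the steps above (hypothetical formula): when re-using
- # an out-of-place formula "result + self_p" for the in-place variant, step 1
- # rewrites "self_p" to "original_self_p", step 3 rewrites "result" to "self_p",
- # giving "self_p + original_self_p", and step 5 wraps the expression so that it is
- # written into the existing forward grad in place (via copy_()).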
- assert (
- len(info.forward_derivatives) == 1
- ) # Only single output inplace should exist
- fw_info = info.forward_derivatives[0]
- formula = fw_info.formula
- def replace_self_with_original_self(formula: str, postfix: str) -> str:
- def repl(m: Match[str]) -> str:
- return f"{m.group(1)}original_self{postfix}{m.group(2)}"
- return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
- if re.search(IDENT_REGEX.format("self_p"), formula):
- if is_exact_match:
- # For manually defined formulas, don't allow the original value to be used
- raise RuntimeError(
- f'The formula for "{f.func.name}" is using the original value of self '
- "that is being modified inplace. This would lead to wrong forward gradients. "
- 'Please use "result" in the formula only.'
- )
- else:
- # When the original formula is out of place, we save a clone of the primal
- # value to be able to access this value if needed
- # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
- formula = replace_self_with_original_self(formula, "_p")
- formula = replace_self_with_original_self(formula, "_t")
- # replace "result" from the formula by "self_p"
- def repl(m: Match[str]) -> str:
- return f"{m.group(1)}self_p{m.group(2)}"
- formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
- required_primals = fw_info.required_inputs_primal
- if re.search(IDENT_REGEX.format("self_p"), formula):
- required_primals = (
- required_primals + ("self",) if required_primals else ("self",)
- )
- if not is_exact_match:
- # NOTE [In-place forward AD formula Optimization]
- #
- # This optimization transforms the formula to directly do inplace, i.e.
- # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
- #
- # 1) the formula satisfies the pattern: "self_t.op(*args)"
- # 2) "op" in (1) needs to be the same as the op the derivative is for
- #
- # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
- # If there is a need, we can relax (2) to allow any op that has an in-place variant
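- # Hypothetical example of this optimization: if the reused formula for an op named
- # "foo" is exactly "self_t.foo(other_p)", the code below rewrites it to
- #   "self_t_raw.defined() ? self_t_raw.foo_(other_p) : self_t.foo(other_p)"
- # instead of wrapping the whole expression in a copy_().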
- is_single_method_on_self_t = False
- directly_do_inplace = False
- op_name: Optional[str] = None
- between_parens: Optional[str] = None
- match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
- if match:
- op_name, between_parens = match.group(1), match.group(2)
- # We want to...
- # Match: self_t.op1(other_p.op2(arg))
- # Avoid: self_t.op1(args) + self_t.op2(args)
- # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
- def check_parens_nest_level_gt_zero(s: str) -> bool:
- level = 1
- for ch in s:
- if ch == ")":
- level -= 1
- if level == 0:
- return False
- if ch == "(":
- level += 1
- return True
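- # e.g. for "self_t.op1(other_p.op2(arg))" the captured text is "other_p.op2(arg)"
- # and the nesting level never drops to zero (a single call on self_t), whereas for
- # "self_t.op1(args) + self_t.op2(args)" the captured text is
- # "args) + self_t.op2(args", whose first ")" closes the outer call, so the check
- # returns False.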
- is_single_method_on_self_t = check_parens_nest_level_gt_zero(
- between_parens
- )
- directly_do_inplace = (
- is_single_method_on_self_t and op_name == info.name
- )
- if directly_do_inplace:
- assert op_name is not None
- assert between_parens is not None
- formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
- else:
- # Make sure that the forward grad is modified inplace when the original formula
- # is out of place
- formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
- required_original_self_value = bool(
- re.search(IDENT_REGEX.format("original_self_p"), formula)
- ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
- forward_derivatives = [
- ForwardDerivative(
- formula=formula,
- var_names=("self",),
- var_types=fw_info.var_types,
- required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
- required_inputs_primal=required_primals,
- required_original_self_value=required_original_self_value,
- is_reusing_outplace_formula=not is_exact_match,
- ),
- ]
- fw_derivative_dict[key] = forward_derivatives
- result.append(
- NativeFunctionWithDifferentiabilityInfo(
- func=f, info=info_dict, fw_derivatives=fw_derivative_dict
- )
- )
- return result
- def is_differentiable(
- name: str, type: Type, info: Optional[DifferentiabilityInfo]
- ) -> bool:
- return type.is_tensor_like() and (
- info is None or name not in info.non_differentiable_arg_names
- )
- def gen_differentiable_outputs(
- fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
- ) -> List[DifferentiableOutput]:
- f = fn.func
- info = fn.info[key] if fn.info else None
- outputs: List[DifferentiableOutput] = [
- DifferentiableOutput(
- name=name,
- type=ret.type,
- cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
- )
- for name, ret in zip(cpp.return_names(f), f.func.returns)
- ]
- output_differentiability = info.output_differentiability if info else None
- if output_differentiability is not None:
- if len(output_differentiability) != len(outputs):
- raise RuntimeError(
- f"The length of output_differentiability ({len(output_differentiability)}), "
- f"does not match the number of outputs ({len(outputs)})."
- )
- differentiable_outputs: List[DifferentiableOutput] = []
- if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
- raise RuntimeError(
- "output_differentiability=False for inplace operation (version_counter won't get updated)"
- )
- for differentiable, output in zip(output_differentiability, outputs):
- if differentiable:
- differentiable_outputs.append(output)
- return differentiable_outputs
- candidate_differentiable_outputs = list(
- filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
- )
- if uses_single_grad(info):
- return candidate_differentiable_outputs[:1]
- else:
- return candidate_differentiable_outputs
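- # Minimal usage sketch of the helpers in this file (illustrative driver code only,
- # not part of the codegen itself):
- #   funcs = match_differentiability_info(native_functions, differentiability_infos)
- #   for fn in funcs:
- #       strategy = dispatch_strategy(fn)          # "use_derived" or "use_type"
- #       outputs = gen_differentiable_outputs(fn)  # differentiable returns ("Default" key)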
|