# lazy_ts_lowering.py

from torchgen.api.lazy import LazyIrSchema
from torchgen.api.types import OptionalCType


def ts_lowering_body(schema: LazyIrSchema) -> str:
    # For now, we just want one IR class decl and soon after also the method defs,
    # and we use the functional version, not out/inplace.
    emplace_arguments = []
    for arg in schema.positional_args:
        if arg.is_lazy_value:
            if isinstance(arg.lazy_type, OptionalCType):
                # Optional lazy values lower to nullptr when the value is absent.
                emplace_arguments.append(
                    f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
                )
                continue
            emplace_arguments.append("loctx->GetOutputOp(operand(i++))")
            continue
        # Non-value positional args (e.g. scalars) are passed through by name.
        emplace_arguments.append(f'"{arg.name}", {arg.name}')
    emplace_arguments_str = "\n    ".join(
        [f"arguments.emplace_back({a});" for a in emplace_arguments]
    )
    emplace_kwarg_values = [
        f'"{arg.name}", loctx->GetOutputOp(operand(i++))'
        for arg in schema.keyword_values
    ]
    emplace_kwarg_scalars = [
        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
    ]
    emplace_kwarguments = "\n    ".join(
        [
            f"kwarguments.emplace_back({a});"
            for a in emplace_kwarg_values + emplace_kwarg_scalars
        ]
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
    return {schema.aten_name}_out;
"""