unify_refinements.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
from collections.abc import Hashable

from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
  4. def infer_symbolic_types_single_pass(traced):
  5. """
  6. Calls our symbolic inferencer once.
  7. """
  8. r = Refine(traced)
  9. r.refine()
  10. mgu = unify_eq(r.constraints)
  11. substitute_all_types(traced.graph, mgu)
  12. def infer_symbolic_types(traced):
  13. """
  14. Calls our symbolic inferencer twice.
  15. This is useful when one pass is not enough
  16. to infer all the information such as the case
  17. for braodcasting.
  18. """
  19. r = Refine(traced)
  20. r.refine()
  21. mgu = unify_eq(r.constraints)
  22. substitute_all_types(traced.graph, mgu)
  23. r = Refine(traced)
  24. r.refine()
  25. mgu = unify_eq(r.constraints)
  26. substitute_all_types(traced.graph, mgu)
  27. r.symbolic_relations()
  28. def convert_eq(list_of_eq):
  29. """
  30. Convert equality constraints in the right format
  31. to be used by unification library.
  32. """
  33. lhs = []
  34. rhs = []
  35. for eq in list_of_eq:
  36. lhs.append(eq.lhs)
  37. rhs.append(eq.rhs)
  38. return tuple(lhs), tuple(rhs)
  39. def unify_eq(list_of_eq):
  40. """
  41. Apply unification to a set of
  42. equality constraints
  43. """
  44. lhs, rhs = convert_eq(list_of_eq)
  45. return unify(lhs, rhs)
  46. def substitute_solution_one_type(mapping, t):
  47. """
  48. Apply the most general unifier to a type
  49. """
  50. if isinstance(t, Var):
  51. if t in mapping.keys():
  52. return mapping[t]
  53. else:
  54. return t
  55. elif isinstance(t, TensorType):
  56. new_type = []
  57. for typ in t.__args__:
  58. if typ in mapping.keys():
  59. new_type.append(mapping[typ])
  60. else:
  61. new_type.append(typ)
  62. return TensorType(tuple(new_type))
  63. elif isinstance(t, list):
  64. new_type = []
  65. for typ in t:
  66. new_type.append(substitute_solution_one_type(mapping, typ))
  67. return new_type
  68. elif isinstance(t, tuple):
  69. new_type = []
  70. for typ in t:
  71. new_type.append(substitute_solution_one_type(mapping, typ))
  72. return tuple(new_type)
  73. else:
  74. return t
  75. def substitute_all_types(graph, mapping):
  76. """
  77. Apply the most general unifier to all types in a graph
  78. till reaching a fixed point. If the input and output graph
  79. are the same, we converge.
  80. """
  81. flag = True
  82. while flag:
  83. flag = False
  84. for k in mapping:
  85. old_mapping_val = mapping[k]
  86. if mapping[k] in mapping.keys():
  87. new_key = mapping[k]
  88. mapping[k] = mapping[new_key]
  89. if old_mapping_val != mapping[k]:
  90. flag = True
  91. for n in graph.nodes:
  92. n.type = substitute_solution_one_type(mapping, n.type)
  93. def check_for_type_equality(g1, g2):
  94. """
  95. A check equality to be used in fixed points.
  96. We do not use graph equality but instead type
  97. equality.
  98. """
  99. for n, m in zip(g1.nodes, g2.nodes):
  100. if n.type != m.type:
  101. return False
  102. return True