from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]


def _refine_and_substitute(traced):
    """
    Run one Refine pass over ``traced`` and write the resulting types back.

    Builds a ``Refine`` over ``traced``, calls ``refine()``, unifies the
    collected ``constraints`` and substitutes the unifier into every node
    type of ``traced.graph``.

    Returns:
        The ``Refine`` instance used for this pass, so callers can run
        further queries (e.g. ``symbolic_relations``) on it.
    """
    r = Refine(traced)
    r.refine()
    mgu = unify_eq(r.constraints)
    substitute_all_types(traced.graph, mgu)
    return r


def infer_symbolic_types_single_pass(traced):
    """
    Calls our symbolic inferencer once.
    """
    _refine_and_substitute(traced)


def infer_symbolic_types(traced):
    """
    Calls our symbolic inferencer twice.

    This is useful when one pass is not enough to infer all the
    information, such as in the case of broadcasting. After the second
    pass, ``symbolic_relations()`` is called on that pass's ``Refine``.
    """
    _refine_and_substitute(traced)
    r = _refine_and_substitute(traced)
    r.symbolic_relations()


def convert_eq(list_of_eq):
    """
    Convert equality constraints into the format expected by the
    unification library.

    Args:
        list_of_eq: iterable of constraint objects exposing ``lhs`` and
            ``rhs`` attributes.

    Returns:
        A pair ``(lhs, rhs)`` of tuples where ``lhs[i]`` and ``rhs[i]``
        are the two sides of the i-th constraint.
    """
    # Materialize once so a one-shot iterator is not exhausted by the
    # first comprehension.
    eqs = list(list_of_eq)
    lhs = tuple(eq.lhs for eq in eqs)
    rhs = tuple(eq.rhs for eq in eqs)
    return lhs, rhs


def unify_eq(list_of_eq):
    """
    Apply unification to a set of equality constraints.

    The left- and right-hand sides are packed into two parallel tuples
    (via ``convert_eq``) which are then unified with each other.

    Returns:
        The result of ``unify``; callers in this module treat it as the
        mapping (most general unifier) from type variables to solutions.
        NOTE(review): the upstream ``unification`` library returns
        ``False`` when unification fails -- confirm how callers should
        handle that case.
    """
    lhs, rhs = convert_eq(list_of_eq)
    return unify(lhs, rhs)


def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier to a single type.

    Args:
        mapping: the unifier, mapping type variables to their solutions.
        t: a ``Var``, ``TensorType``, list, tuple, or any other value.

    Returns:
        - ``Var``: its image in ``mapping``, or ``t`` itself if unmapped.
        - ``TensorType``: a new ``TensorType`` whose dimensions are each
          replaced by their image in ``mapping`` (one level only; the
          dimensions are not substituted recursively).
        - list / tuple: a new list / tuple with every element
          substituted recursively.
        - anything else: ``t`` unchanged.
    """
    if isinstance(t, Var):
        return mapping.get(t, t)
    if isinstance(t, TensorType):
        return TensorType(tuple(mapping.get(dim, dim) for dim in t.__args__))
    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, item) for item in t]
    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, item) for item in t)
    return t


def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier to all types in a graph.

    First, chains inside ``mapping`` are resolved in place: whenever a
    value is itself a key of ``mapping`` it is replaced by that key's
    value, repeating until a full sweep changes nothing (a fixed point).
    Then every node's ``type`` in ``graph`` is rewritten once with
    ``substitute_solution_one_type``.

    Note: ``mapping`` is mutated in place.
    """
    changed = True
    while changed:
        changed = False
        # Only existing keys are reassigned, so the dict never changes
        # size and assigning while iterating over it is safe.
        for key in mapping:
            old_val = mapping[key]
            if old_val in mapping:
                mapping[key] = mapping[old_val]
            if old_val != mapping[key]:
                changed = True

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)


def check_for_type_equality(g1, g2):
    """
    A check equality to be used in fixed points.

    We do not use graph equality but instead type equality: the graphs
    are considered equal when they have the same number of nodes and
    their node types match position by position.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() alone would silently ignore trailing nodes of the longer graph.
    if len(nodes1) != len(nodes2):
        return False
    return not any(n.type != m.type for n, m in zip(nodes1, nodes2))