123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- from torch.fx.experimental.graph_gradual_typechecker import Refine
- from torch.fx.tensor_type import TensorType
- from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
def infer_symbolic_types_single_pass(traced):
    """
    Run a single round of symbolic type inference on ``traced``.

    A ``Refine`` instance generates constraints for the traced module,
    the constraints are unified, and the resulting substitution is
    applied to the types of the nodes in ``traced.graph``.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
def infer_symbolic_types(traced):
    """
    Run the symbolic inferencer twice over ``traced``.

    One pass is not always enough to infer all the information,
    for example in the case of broadcasting, so the
    refine / unify / substitute sequence is performed two times.
    Afterwards ``symbolic_relations()`` is called on the refiner
    created for the second pass.
    """
    for _ in range(2):
        # A fresh Refine is built for each pass so that the second pass
        # sees the node types written by the first one.
        refiner = Refine(traced)
        refiner.refine()
        solution = unify_eq(refiner.constraints)
        substitute_all_types(traced.graph, solution)

    # ``refiner`` is the instance from the second (last) pass.
    refiner.symbolic_relations()
def convert_eq(list_of_eq):
    """
    Split equality constraints into the format expected by the
    unification library.

    Each element of ``list_of_eq`` must expose ``lhs`` and ``rhs``
    attributes. Returns a pair ``(lhs_tuple, rhs_tuple)`` whose
    entries line up position by position with the input constraints.
    """
    # Walk the input exactly once so one-shot iterables keep working.
    sides = [(constraint.lhs, constraint.rhs) for constraint in list_of_eq]
    left = tuple(pair[0] for pair in sides)
    right = tuple(pair[1] for pair in sides)
    return left, right
def unify_eq(list_of_eq):
    """
    Unify a collection of equality constraints.

    The constraints are first split into a tuple of left-hand sides
    and a tuple of right-hand sides, which are then handed to
    ``unify``; its result is returned unchanged.
    """
    return unify(*convert_eq(list_of_eq))
def substitute_solution_one_type(mapping, t):
    """
    Apply a substitution (most general unifier) to a single type.

    - A ``Var`` is replaced by its image in ``mapping`` when it has one.
    - A ``TensorType`` is rebuilt with each of its ``__args__`` replaced
      by its image in ``mapping`` when present (no recursion into args).
    - Lists and tuples are processed element by element, recursively,
      and returned as a new list / tuple respectively.
    - Anything else is returned unchanged.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping.keys() else t

    if isinstance(t, TensorType):
        dims = tuple(
            mapping[dim] if dim in mapping.keys() else dim
            for dim in t.__args__
        )
        return TensorType(dims)

    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, item) for item in t]

    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, item) for item in t)

    return t
def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier to the type of every node in a graph.

    ``mapping`` is first resolved in place: whenever the value stored
    for a key is itself a key of ``mapping``, it is replaced by that
    key's value, and this is repeated until a full sweep changes
    nothing. The resolved mapping is then applied to ``n.type`` of
    each node in ``graph.nodes``.
    """
    changed = True
    while changed:
        changed = False
        for var in mapping:
            previous = mapping[var]
            # Follow one link of the chain var -> previous -> mapping[previous].
            if previous in mapping.keys():
                mapping[var] = mapping[previous]
            if previous != mapping[var]:
                changed = True

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)
def check_for_type_equality(g1, g2):
    """
    Equality check intended for fixed-point computations.

    Rather than comparing the graphs themselves, the ``type`` of the
    nodes of ``g1`` and ``g2`` are compared pairwise, in order.
    Returns ``False`` as soon as a pair differs and ``True`` otherwise.
    Nodes are paired with ``zip``, so the comparison stops at the end
    of the shorter node sequence.
    """
    node_pairs = zip(g1.nodes, g2.nodes)
    return not any(left.type != right.type for left, right in node_pairs)
|