123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- import collections
- import dataclasses
- import itertools
- import logging
- import typing
- from typing import Callable, cast, Dict, List, Optional, Set, Tuple, Union
- import sympy
- from .codegen.common import index_prevent_reordering
- from .utils import (
- get_dtype_size,
- sympy_product,
- sympy_str,
- sympy_subs,
- sympy_symbol,
- VarRanges,
- )
- from .virtualized import V
- log = logging.getLogger(__name__)
- Dep = Union["MemoryDep", "StarDep", "WeakDep"]
- class MemoryDep(typing.NamedTuple):
- name: str
- index: sympy.Expr # type: ignore[assignment]
- size: Tuple[sympy.Expr, ...]
- def broadcast_extend_sizes(self, extra_sizes: List[sympy.Expr]) -> "MemoryDep":
- size = (*self.size, *[x for x in extra_sizes if x != 1])
- return MemoryDep(self.name, self.index, size)
- def maybe_swap_sizes(self) -> "MemoryDep":
- # swap only in simple cases where index is trivial and
- # there are just 2 sizes
- if (
- len(self.size) == 2
- and len(self.index.args) == 0
- and cast(sympy.Symbol, self.index).name == canonicalization_prefix() + "0"
- ):
- c = canonicalization_prefix()
- size = (self.size[1], self.size[0])
- s0 = sympy_symbol(c + "0")
- s1 = sympy_symbol(c + "1")
- index = sympy_subs(self.index, {s0: s1})
- return MemoryDep(self.name, index, size)
- else:
- return self
- def strip_last_size(self) -> "MemoryDep":
- nsizes = len(self.size)
- if not (nsizes >= 1 and len(self.index.args) <= nsizes - 1):
- return self
- # make sure last dim index is not used
- prefix = canonicalization_prefix()
- len_prefix = len(prefix)
- prefixes = [
- fs.name[:len_prefix]
- for fs in cast(Set[sympy.Symbol], self.index.free_symbols)
- ]
- assert (
- len(prefixes) == 0 or prefix in prefixes
- ), "index expression should contain canonicalized symbols"
- last_index = f"{prefix}{len(self.size)-1}"
- if last_index not in self.index.free_symbols:
- size = self.size[:-1]
- return MemoryDep(self.name, self.index, size)
- else:
- return self
- def rename(self, renames: Dict[str, str]) -> "MemoryDep":
- if self.name in renames:
- return MemoryDep(renames[self.name], self.index, self.size)
- return self
- def numbytes_hint(self):
- vars = set(self.index.free_symbols)
- size_vars_used = []
- for var in vars:
- if var.name.startswith(canonicalization_prefix()):
- # Sometimes with indirect indexing we have very weird symbol names
- assert " " not in var.name
- size_vars_used.append(int(var.name[len(canonicalization_prefix()) :]))
- return V.graph.sizevars.size_hint(
- sympy_product([self.size[i] for i in size_vars_used])
- ) * get_dtype_size(V.graph.get_dtype(self.name))
- def is_contiguous(self) -> bool:
- return isinstance(self.index, (sympy.Symbol, sympy.Integer))
- class StarDep(typing.NamedTuple):
- # depends on the entire buffer
- name: str
- def rename(self, renames: Dict[str, str]) -> "StarDep":
- if self.name in renames:
- return StarDep(renames[self.name])
- return self
- def numbytes_hint(self):
- from .ir import MultiOutputLayout
- if self.name in V.graph.name_to_buffer:
- buf = V.graph.name_to_buffer[self.name]
- elif self.name in V.graph.graph_inputs:
- buf = V.graph.graph_inputs[self.name]
- else:
- return 1
- if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout):
- # NB: Too annoying to acquire, should only be used for instrumentation
- return 1
- return V.graph.sizevars.size_hint(
- sympy_product(buf.get_size())
- ) * get_dtype_size(buf.get_dtype())
- def is_contiguous(self) -> bool:
- return False
- # Used for tracking mutation ordering
- # if A reads a buffer and B mutates it
- # B must be ordered after A
- class WeakDep(typing.NamedTuple):
- name: str
- def rename(self, renames: Dict[str, str]) -> "WeakDep":
- if self.name in renames:
- return WeakDep(renames[self.name])
- return self
- def numbytes_hint(self):
- return 1 # Purely inserted for ordering, not an actual dep
- def is_contiguous(self) -> bool:
- return False
- class IndexExprDep(typing.NamedTuple):
- index: sympy.Expr # type: ignore[assignment]
- size: Tuple[sympy.Expr, ...]
- @dataclasses.dataclass
- class ReadWrites:
- reads: Set[Dep]
- writes: Set[Dep]
- index_exprs: Set[IndexExprDep]
- range_vars: Optional[List[sympy.Expr]] = None
- var_ranges: Optional[VarRanges] = None
- def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
- return ReadWrites(
- {dep.rename(renames) for dep in self.reads},
- {dep.rename(renames) for dep in self.writes},
- self.index_exprs,
- self.range_vars,
- self.var_ranges,
- )
- def with_read(self, dep: Dep) -> "ReadWrites":
- assert isinstance(dep, (WeakDep, StarDep))
- return ReadWrites(
- set.union(self.reads, {dep}),
- self.writes,
- self.index_exprs,
- self.range_vars,
- self.var_ranges,
- )
- def merge(self, other):
- reads = set.union(self.reads, other.reads)
- writes = set.union(self.writes, other.writes)
- index_exprs = set.union(self.index_exprs, other.index_exprs)
- return ReadWrites(
- reads - writes,
- writes,
- index_exprs,
- )
- def remove_reads(self, rem_reads):
- return ReadWrites(
- self.reads - rem_reads,
- self.writes,
- self.index_exprs,
- self.range_vars,
- self.var_ranges,
- )
- class _RecordLoadStoreInner(V.MockHandler):
- def __init__(self, var_ranges: VarRanges, normalize: bool):
- super().__init__()
- self._reads: Set[MemoryDep] = set()
- self._writes: Set[MemoryDep] = set()
- self._index_exprs: Set[IndexExprDep] = set()
- self._var_ranges: VarRanges = var_ranges
- self._normalize: bool = normalize
- def canonicalize(
- self, index: sympy.Expr
- ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
- sizes = list(self._var_ranges.values())
- sizes = [V.graph.sizevars.simplify(x) for x in sizes]
- if not self._normalize:
- return index, tuple([x for x in sizes if x != 1])
- # Try to further simplify the indexes even if simplify_loops didn't
- # convert it to the simpliest form because of the interference from
- # different indexing formulas.
- index_vars = list(self._var_ranges.keys())
- new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
- index_vars,
- sizes,
- index_prevent_reordering([index], index_vars, sizes),
- )
- # assign new variables each dimension to deal with numbering mismatches
- # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
- _, add_var = var_builder(canonicalization_prefix())
- replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
- index = sympy_subs(sympy.expand(index), replacement)
- return index, tuple(new_sizes)
- def load(self, name: str, index: sympy.Expr) -> str:
- canonicalized_index, canonicalized_size = self.canonicalize(index)
- self._reads.add(MemoryDep(name, canonicalized_index, canonicalized_size))
- return f"load({name}, {sympy_str(index)})"
- def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
- canonicalized_index, canonicalized_size = self.canonicalize(index)
- self._writes.add(MemoryDep(name, canonicalized_index, canonicalized_size))
- return f"store({name}, {sympy_str(index)}, {value}, {mode})"
- def reduction(
- self, name: str, dtype, src_dtype, reduction_type, index, value
- ) -> str:
- return self.store(name, index, f"reduce_{reduction_type})({value})")
- def index_expr(self, index: sympy.Expr, dtype) -> str:
- canonicalized_index, canonicalized_size = self.canonicalize(index)
- self._index_exprs.add(IndexExprDep(canonicalized_index, canonicalized_size))
- return f"index_expr({sympy_str(index)}, {dtype})"
- class RecordLoadStore(V.KernelFormatterHandler):
- def __init__(self, var_ranges: VarRanges, normalize: bool):
- parent_handler = _RecordLoadStoreInner(
- var_ranges=var_ranges, normalize=normalize
- )
- super().__init__(parent_handler=parent_handler)
- def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
- cnt = itertools.count()
- var_ranges: VarRanges = collections.OrderedDict()
- def add_var(length: sympy.Expr) -> sympy.Symbol:
- v = sympy_symbol(f"{prefix}{next(cnt)}")
- var_ranges[v] = length
- return v
- return var_ranges, add_var
- def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
- var_ranges, add_var = var_builder(prefix)
- args: List[List[sympy.Symbol]] = []
- for size in argsizes:
- args.append(list(map(add_var, size)))
- return args, var_ranges
- def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
- from .ir import SqueezeView
- var_ranges, add_var = var_builder(prefix)
- args: List[List[sympy.Expr]] = []
- new_sizes: List[List[sympy.Expr]] = []
- for size in argsizes:
- new_size, reindex = SqueezeView.squeezer(size)
- new_sizes.append(new_size)
- args.append(reindex(list(map(add_var, new_size))))
- return new_sizes, args, var_ranges
- def extract_read_writes(
- fn: Callable,
- *argsizes: Tuple[sympy.Expr, ...],
- normalize: bool = False,
- prefix: str = "d",
- ):
- _, args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
- rw = RecordLoadStore(var_ranges, normalize=normalize)
- with V.set_ops_handler(rw): # type: ignore[call-arg]
- fn(*args)
- if normalize:
- range_vars = [] # Number of vars could differ due to normalization
- else:
- range_vars = [*itertools.chain(*args)]
- inner = rw.parent_handler
- return ReadWrites(
- set(inner._reads),
- set(inner._writes),
- inner._index_exprs,
- range_vars,
- var_ranges,
- )
- def canonicalization_prefix():
- return "c"
|