logical_schema.py

# type: ignore[assignment]
from dataclasses import dataclass
from enum import auto, Enum
from typing import List, Union, Dict


################################################################################
# The following section defines the permissible argument types for operators.


# Copied from torchgen/model.py
class ScalarType(Enum):
    u8 = auto()    # torch.uint8
    i8 = auto()    # torch.int8
    i16 = auto()   # torch.int16 or torch.short
    i32 = auto()   # torch.int32 or torch.int
    i64 = auto()   # torch.int64 or torch.long
    f16 = auto()   # torch.float16 or torch.half
    f32 = auto()   # torch.float32 or torch.float
    f64 = auto()   # torch.float64 or torch.double
    c32 = auto()   # torch.complex32
    c64 = auto()   # torch.complex64 or torch.cfloat
    c128 = auto()  # torch.complex128 or torch.cdouble
    b8 = auto()    # torch.bool
    bf16 = auto()  # torch.bfloat16
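

# Illustrative sketch only, not part of the schema: a hypothetical helper showing
# how ScalarType entries could be mapped back to torch dtypes at load time.
# torch is imported lazily so the schema module itself stays torch-free.
def _example_scalar_type_to_torch_dtype(scalar_type: "ScalarType"):
    import torch

    return {
        ScalarType.u8: torch.uint8,
        ScalarType.i8: torch.int8,
        ScalarType.i16: torch.int16,
        ScalarType.i32: torch.int32,
        ScalarType.i64: torch.int64,
        ScalarType.f16: torch.float16,
        ScalarType.f32: torch.float32,
        ScalarType.f64: torch.float64,
        ScalarType.c32: torch.complex32,
        ScalarType.c64: torch.complex64,
        ScalarType.c128: torch.complex128,
        ScalarType.b8: torch.bool,
        ScalarType.bf16: torch.bfloat16,
    }[scalar_type]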


# Copied from torch/_C/__init__.pyi.in
class Layout(Enum):
    # Defined in torch/csrc/utils/tensor_layouts.cpp
    strided = auto()
    sparse_coo = auto()
    sparse_csr = auto()
    sparse_csc = auto()
    sparse_bsr = auto()
    sparse_bsc = auto()
    _mkldnn = auto()


# Copied from torch/_C/__init__.pyi.in
class MemoryFormat(Enum):
    # Defined in torch/csrc/utils/tensor_memoryformats.cpp
    contiguous_format = auto()
    channels_last = auto()
    channels_last_3d = auto()
    preserve_format = auto()


# Copied from torch/_C/__init__.pyi.in
@dataclass
class Device:
    # Defined in torch/csrc/Device.cpp
    type: str
    index: int


@dataclass
class SymInt:  # Union, ONLY EXACTLY ONE of the following fields can be set
    as_int: int = None
    as_sym: str = None


# !!! To support t.item(), we need to introduce SymFloat
# @dataclass
# class SymFloat:  # Union, ONLY EXACTLY ONE of the following fields can be set
#     as_float: float = None
#     as_sym: str = None
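

# Illustrative sketch only, not part of the schema: how the SymInt union
# convention is intended to be used. Exactly one field is set; the symbol
# name "s0" below is hypothetical.
def _example_symints():
    concrete = SymInt(as_int=4)     # a statically known dimension of size 4
    symbolic = SymInt(as_sym="s0")  # a dimension bound to the symbol "s0"
    return concrete, symbolic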


# This is a Tensor Argument used in the args of a node
# We intentionally don't store the tensor's storage, nor the tensor's metadata here,
# as the same tensor argument can be used in multiple nodes, and we want to avoid storing the same data multiple times.
# In other words, this field is a reference to the tensor, not the tensor itself.
@dataclass
class TensorArgument:
    name: str  # identifier of the tensor, which must exist in the graph's tensor_values


# This is a SymInt Argument used in the args of a node
# We intentionally don't store the SymInt's value here, as the same SymInt argument can be used in multiple nodes
# This field is a reference to the SymInt
@dataclass
class SymIntArgument:
    name: str  # identifier of the symint, which must exist in the graph's symint_values


# Permissible return types for operators
# !!! Notice: this assumes that a node can only return Tensor(s) and SymInt(s), and not other int/float/bool types...
# !!! What about .item()? Do we need to handle this?
@dataclass
class ReturnArgument:  # Union, ONLY EXACTLY ONE of the following fields can be set
    as_tensor: TensorArgument = None
    # !!! ATM, no operator has return type as Tensor[], might need this later?
    # as_tensors: List[TensorArgument] = None
    as_symint: SymIntArgument = None


# Permissible argument types for operators
# !!! This is a Union struct, but there is no good python construct to model this
@dataclass
class Argument:  # Union, ONLY EXACTLY ONE of the following fields can be set
    # A special type for representing python None in the arguments
    # This must only be used for ops that accept None as an argument, e.g. Tensor?, Scalar?, int?, int[]?
    as_none: bool = None

    as_tensor: TensorArgument = None
    as_tensors: List[TensorArgument] = None  # Tensor[], used by aten.cat, and condition ops

    as_symint: SymIntArgument = None  # SymInt can be an argument; there are symint entries in native_function.yaml
    as_symints: List[SymIntArgument] = None  # SymInt[] can be an argument; there are symint[] entries in native_function.yaml

    as_bool: bool = None
    # !!! There are uses of bool[3] in canonical aten ops, consider if we can simplify this
    as_bools: List[bool] = None  # for bool[]

    as_int: int = None
    as_ints: List[int] = None  # for int[]

    as_float: float = None
    as_floats: List[float] = None  # for float[]

    as_str: str = None
    # List[str],  # !!! There is no str[] in native_function.yaml. Consider if this is needed for expressiveness

    # Graph,  # !!! Consider how to handle the condition op, which needs to pass in a graph for the branch
    # List[Graph],  # !!! What about a list of graphs? Do we need this?
    as_gm: "GraphModule" = None  # !!! ATM, torch.cond models a branch as a GraphModule

    # !!! The following types don't have a list version in native_function.yaml
    as_scalar_type: ScalarType = None
    as_memory_format: MemoryFormat = None
    as_layout: Layout = None
    as_device: Device = None
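

# Illustrative sketch only, not part of the schema: constructing a few Argument
# values under the "exactly one field set" union convention. The tensor name
# "t0" is hypothetical.
def _example_arguments():
    return [
        Argument(as_tensor=TensorArgument(name="t0")),  # a Tensor argument, by reference
        Argument(as_int=1),                             # a plain int argument
        Argument(as_ints=[0, 1]),                       # an int[] argument, e.g. a dims list
        Argument(as_none=True),                         # python None for an optional argument
    ]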


################################################################################
# The following section defines the schema for serializing a concrete tensor.


# TensorMeta is a description of a tensor, without the actual data (it effectively maps to FakeTensor)
# TensorMeta has multiple uses:
#   1. Represent the properties of a concrete tensor backed by a storage
#      - This is used in the serialization of a concrete tensor, e.g. a model weight
#      - In this case, sizes and strides must be concrete ints, and cannot be symbolic
#      - strides and storage_offset have to be used to correctly reconstruct the tensor from the storage
#   2. Represent the properties of a virtual tensor (see TensorValue below)
#      - In this case, sizes and strides can be either concrete ints or symbolic ints
#      - device/strides/storage_offset/layout/memory_format are tied to pytorch's implementation.
#        They are a faithful capture of the tensor's details during pytorch's execution while tracing.
#        However, it's up to the downstream system how to utilize these fields.
#        In other words, these fields are suggestive rather than mandatory.
@dataclass
class TensorMeta:
    dtype: ScalarType
    sizes: List[SymInt]

    # needed for training
    requires_grad: bool

    # !!! see the description above; there are subtle differences in how these fields should be interpreted
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout
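

# Illustrative sketch only, not part of the schema: a TensorMeta for a concrete,
# contiguous 2x3 float32 CPU tensor. All values below are hypothetical.
def _example_tensor_meta():
    return TensorMeta(
        dtype=ScalarType.f32,
        sizes=[SymInt(as_int=2), SymInt(as_int=3)],
        requires_grad=False,
        device=Device(type="cpu", index=0),
        strides=[SymInt(as_int=3), SymInt(as_int=1)],  # row-major strides for a 2x3 tensor
        storage_offset=SymInt(as_int=0),
        layout=Layout.strided,
    )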


@dataclass
class Buffer:
    # data stored in big endian
    buffer: bytes


# External data needs to be stored in big endian
@dataclass
class ExternalBuffer:
    location: str
    offset: str  # !!! Consider using int, but int has an int_max limitation
    length: str  # !!! Consider using int, but int has an int_max limitation
    checksum: str


@dataclass
class Storage:
    class DataLocation(Enum):
        Internal = auto()
        External = auto()

    data_location: DataLocation
    data: Union[Buffer, ExternalBuffer]


# This is a concrete tensor backed by a storage
@dataclass
class Tensor:
    # storage
    storage: Storage
    # metadata
    meta: TensorMeta
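

# Illustrative sketch only, not part of the schema: a concrete Tensor whose data
# is stored inline. The 24 bytes correspond to the 2x3 float32 tensor described
# by _example_tensor_meta(); the byte content here is just zeros for illustration.
def _example_tensor():
    storage = Storage(
        data_location=Storage.DataLocation.Internal,
        data=Buffer(buffer=b"\x00" * 24),  # 6 elements * 4 bytes each, big endian
    )
    return Tensor(storage=storage, meta=_example_tensor_meta())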


################################################################################
# The following section defines the schema of the 3-level construct: GraphModule, Graph, Node.


# TensorValue has no corresponding class in fx
# TensorValue is a "tensor result" that is passed between nodes in the graph
# TensorValue is a named virtual tensor, with a TensorMeta that describes the properties of the tensor
@dataclass
class TensorValue:
    name: str  # unique identifier of the TensorValue, referenced in the Argument.as_tensor field
    meta: TensorMeta  # tensor meta


@dataclass
class SymIntValue:
    name: str  # unique identifier of the SymIntValue, referenced in the Argument.as_symint field
    value: SymInt


@dataclass
class NodeMetadata:
    stack_trace: str  # source info of a node
    nn_module_stack: str  # stack of nn.Modules that the node originates from
    extra: Dict[str, str]  # arbitrary string-string pairs for extra metadata


# Maps to fx.Node
# A Node can only be a 'call_function' op
# 'placeholder' and 'output' are serialized as inputs and outputs of the Graph
# 'get_attr' is not needed anymore, as it's an implicit lookup from the GraphModule's parameters/buffers
# 'call_method' and 'call_module' are not supported, as they are not used in the canonical FX Graph
@dataclass
class Node:
    # fully qualified name of the target, e.g. aten.add.Tensor
    # !!! Consider using a structured operator name instead of a string
    target: str

    args: List[Argument]

    # kwargs for this node
    # !!! Not all types in Argument are used as kwargs, e.g. TensorArgument should not be used as kwargs
    # Do we want to enforce this in the schema? i.e. only allow certain types to be used as kwargs?
    kwargs: Dict[str, Argument]

    # A list of Arguments returned by this node
    outputs: List[ReturnArgument]

    metadata: NodeMetadata  # metadata fields for this node
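

# Illustrative sketch only, not part of the schema: a Node computing
# t2 = aten.add.Tensor(t0, t1). The value names and the empty metadata are
# hypothetical.
def _example_node():
    return Node(
        target="aten.add.Tensor",
        args=[
            Argument(as_tensor=TensorArgument(name="t0")),
            Argument(as_tensor=TensorArgument(name="t1")),
        ],
        kwargs={},
        outputs=[ReturnArgument(as_tensor=TensorArgument(name="t2"))],
        metadata=NodeMetadata(stack_trace="", nn_module_stack="", extra={}),
    )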


# Maps to fx.Graph
@dataclass(init=False)
class Graph:
    # Maps to fx.graph's placeholder nodes.
    # !!! Do we allow SymInt as graph input?
    # !!! need to think about where to store the metadata for placeholder nodes
    inputs: List[TensorArgument]

    # Maps to fx.graph's output node.
    # !!! Do we allow SymInt as graph output?
    # !!! need to think about where to store the metadata for the original output node
    outputs: List[TensorArgument]

    # maps to the computation nodes in fx.graph
    # Placeholder nodes and the output node are not included in this list.
    # Only call_function nodes can be included in this list
    nodes: List[Node]

    # Tensor values that appear in the graph
    # They could be graph inputs, graph outputs, or intermediate tensor values produced by nodes
    tensor_values: List[TensorValue]

    # SymInt values that appear in the graph
    symint_values: List[SymIntValue]
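

# Illustrative sketch only, not part of the schema: assembling a one-node Graph
# (t2 = t0 + t1). Because Graph is declared with init=False, fields are assigned
# manually. All names reuse the hypothetical values from the sketches above.
def _example_graph():
    graph = Graph()
    graph.inputs = [TensorArgument(name="t0"), TensorArgument(name="t1")]
    graph.outputs = [TensorArgument(name="t2")]
    graph.nodes = [_example_node()]
    graph.tensor_values = [
        TensorValue(name=name, meta=_example_tensor_meta()) for name in ("t0", "t1", "t2")
    ]
    graph.symint_values = []
    return graph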


# Maps to fx.GraphModule
# This is the top-level construct for the model
@dataclass(init=False)
class GraphModule:
    # A readable name for the model, potentially maps to GraphModule's self.__class__.__name__
    # This is not an identifier for the GraphModule
    name: str

    graph: Graph  # Only one Graph per GraphModule

    # maps to GraphModule's meta, which is a Dict[str, Any], but we only support string keys and string values.
    metadata: Dict[str, str]

    # Stateful fields of the graph module
    # The name of the tensor will be used to bind to the TensorValues of the Graph
    # !!! Consider storing them in the Graph.
    # There are functional differences between buffers and parameters, so they are stored separately.
    parameters: Dict[str, Tensor]
    buffers: Dict[str, Tensor]

    # !!! model constants: constant, etc.

    # !!! Might also need to store the shape_env for symints, but it's unclear how downstream systems will use it.
    # !!! Consider storing it in the GraphModule, or in the Graph.
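

# Illustrative sketch only, not part of the schema: a minimal stateless
# GraphModule wrapping the graph from _example_graph(). The model name is
# hypothetical; a real model would populate parameters and buffers.
def _example_graph_module():
    gm = GraphModule()
    gm.name = "AddModule"
    gm.graph = _example_graph()
    gm.metadata = {}
    gm.parameters = {}
    gm.buffers = {}
    return gm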