# bytecode_transformation.py
  1. import dataclasses
  2. import dis
  3. import itertools
  4. import sys
  5. import types
  6. from typing import Any, Dict, List, Optional, Tuple
  7. from .bytecode_analysis import (
  8. propagate_line_nums,
  9. remove_extra_line_nums,
  10. stacksize_analysis,
  11. )
  12. @dataclasses.dataclass
  13. class Instruction:
  14. """A mutable version of dis.Instruction"""
  15. opcode: int
  16. opname: str
  17. arg: Optional[int]
  18. argval: Any
  19. offset: Optional[int] = None
  20. starts_line: Optional[int] = None
  21. is_jump_target: bool = False
  22. # extra fields to make modification easier:
  23. target: Optional["Instruction"] = None
  24. def __hash__(self):
  25. return id(self)
  26. def __eq__(self, other):
  27. return id(self) == id(other)
  28. def convert_instruction(i: dis.Instruction):
  29. return Instruction(
  30. i.opcode,
  31. i.opname,
  32. i.arg,
  33. i.argval,
  34. i.offset,
  35. i.starts_line,
  36. i.is_jump_target,
  37. )
  38. class _NotProvided:
  39. pass
  40. def create_instruction(name, arg=None, argval=_NotProvided, target=None):
  41. if argval is _NotProvided:
  42. argval = arg
  43. return Instruction(
  44. opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target
  45. )
# Python 3.11 remaps: the helpers below select opcodes by interpreter version.
  47. def create_jump_absolute(target):
  48. inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE"
  49. return create_instruction(inst, target=target)
  50. def create_dup_top():
  51. if sys.version_info >= (3, 11):
  52. return create_instruction("COPY", 1)
  53. return create_instruction("DUP_TOP")
  54. def create_rot_n(n):
  55. """
  56. Returns a "simple" sequence of instructions that rotates TOS to the n-th
  57. position in the stack. For Python < 3.11, returns a single ROT_*
  58. instruction. If no such instruction exists, an error is raised and the
  59. caller is expected to generate an equivalent sequence of instructions.
  60. For Python >= 3.11, any rotation can be expressed as a simple sequence of
  61. swaps.
  62. """
  63. if n <= 1:
  64. # don't rotate
  65. return []
  66. if sys.version_info >= (3, 11):
  67. # rotate can be expressed as a sequence of swap operations
  68. # e.g. rotate 3 is equivalent to swap 3, swap 2
  69. return [create_instruction("SWAP", i) for i in range(n, 1, -1)]
  70. # ensure desired rotate function exists
  71. if sys.version_info < (3, 8) and n >= 4:
  72. raise AttributeError(f"rotate {n} not supported for Python < 3.8")
  73. if sys.version_info < (3, 10) and n >= 5:
  74. raise AttributeError(f"rotate {n} not supported for Python < 3.10")
  75. if n <= 4:
  76. return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])]
  77. return [create_instruction("ROT_N", n)]
  78. def lnotab_writer(lineno, byteno=0):
  79. """
  80. Used to create typing.CodeType.co_lnotab
  81. See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
  82. This is the internal format of the line number table if Python < 3.10
  83. """
  84. assert sys.version_info < (3, 10)
  85. lnotab = []
  86. def update(lineno_new, byteno_new):
  87. nonlocal byteno, lineno
  88. while byteno_new != byteno or lineno_new != lineno:
  89. byte_offset = max(0, min(byteno_new - byteno, 255))
  90. line_offset = max(-128, min(lineno_new - lineno, 127))
  91. assert byte_offset != 0 or line_offset != 0
  92. byteno += byte_offset
  93. lineno += line_offset
  94. lnotab.extend((byte_offset, line_offset & 0xFF))
  95. return lnotab, update
  96. def linetable_writer(first_lineno):
  97. """
  98. Used to create typing.CodeType.co_linetable
  99. See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
  100. This is the internal format of the line number table if Python >= 3.10
  101. """
  102. assert sys.version_info >= (3, 10)
  103. linetable = []
  104. lineno = first_lineno
  105. lineno_delta = 0
  106. byteno = 0
  107. def _update(byteno_delta, lineno_delta):
  108. while byteno_delta != 0 or lineno_delta != 0:
  109. byte_offset = max(0, min(byteno_delta, 254))
  110. line_offset = max(-127, min(lineno_delta, 127))
  111. assert byte_offset != 0 or line_offset != 0
  112. byteno_delta -= byte_offset
  113. lineno_delta -= line_offset
  114. linetable.extend((byte_offset, line_offset & 0xFF))
  115. def update(lineno_new, byteno_new):
  116. nonlocal lineno, lineno_delta, byteno
  117. byteno_delta = byteno_new - byteno
  118. byteno = byteno_new
  119. _update(byteno_delta, lineno_delta)
  120. lineno_delta = lineno_new - lineno
  121. lineno = lineno_new
  122. def end(total_bytes):
  123. _update(total_bytes - byteno, lineno_delta)
  124. return linetable, update, end
  125. def assemble(instructions: List[Instruction], firstlineno):
  126. """Do the opposite of dis.get_instructions()"""
  127. code = []
  128. if sys.version_info < (3, 10):
  129. lnotab, update_lineno = lnotab_writer(firstlineno)
  130. else:
  131. lnotab, update_lineno, end = linetable_writer(firstlineno)
  132. for inst in instructions:
  133. if inst.starts_line is not None:
  134. update_lineno(inst.starts_line, len(code))
  135. arg = inst.arg or 0
  136. code.extend((inst.opcode, arg & 0xFF))
  137. if sys.version_info >= (3, 11):
  138. for _ in range(instruction_size(inst) // 2 - 1):
  139. code.extend((0, 0))
  140. if sys.version_info >= (3, 10):
  141. end(len(code))
  142. return bytes(code), bytes(lnotab)
  143. def virtualize_jumps(instructions):
  144. """Replace jump targets with pointers to make editing easier"""
  145. jump_targets = {inst.offset: inst for inst in instructions}
  146. for inst in instructions:
  147. if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel:
  148. for offset in (0, 2, 4, 6):
  149. if jump_targets[inst.argval + offset].opcode != dis.EXTENDED_ARG:
  150. inst.target = jump_targets[inst.argval + offset]
  151. break
  152. _REL_JUMPS = set(dis.hasjrel)
  153. def flip_jump_direction(instruction):
  154. if sys.version_info < (3, 11):
  155. raise RuntimeError("Cannot flip jump direction in Python < 3.11")
  156. if "FORWARD" in instruction.opname:
  157. instruction.opname = instruction.opname.replace("FORWARD", "BACKWARD")
  158. elif "BACKWARD" in instruction.opname:
  159. instruction.opname = instruction.opname.replace("BACKWARD", "FORWARD")
  160. else:
  161. raise AttributeError("Instruction is not a forward or backward jump")
  162. instruction.opcode = dis.opmap[instruction.opname]
  163. assert instruction.opcode in _REL_JUMPS
  164. def devirtualize_jumps(instructions):
  165. """Fill in args for virtualized jump target after instructions may have moved"""
  166. indexof = {id(inst): i for i, inst, in enumerate(instructions)}
  167. jumps = set(dis.hasjabs).union(set(dis.hasjrel))
  168. for inst in instructions:
  169. if inst.opcode in jumps:
  170. target = inst.target
  171. target_index = indexof[id(target)]
  172. for offset in (1, 2, 3):
  173. if (
  174. target_index >= offset
  175. and instructions[target_index - offset].opcode == dis.EXTENDED_ARG
  176. ):
  177. target = instructions[target_index - offset]
  178. else:
  179. break
  180. if inst.opcode in dis.hasjabs:
  181. if sys.version_info < (3, 10):
  182. inst.arg = target.offset
  183. elif sys.version_info < (3, 11):
  184. # `arg` is expected to be bytecode offset, whereas `offset` is byte offset.
  185. # Divide since bytecode is 2 bytes large.
  186. inst.arg = int(target.offset / 2)
  187. else:
  188. raise RuntimeError("Python 3.11+ should not have absolute jumps")
  189. else: # relative jump
  190. # byte offset between target and next instruction
  191. inst.arg = int(target.offset - inst.offset - instruction_size(inst))
  192. if inst.arg < 0:
  193. if sys.version_info < (3, 11):
  194. raise RuntimeError("Got negative jump offset for Python < 3.11")
  195. inst.arg = -inst.arg
  196. # forward jumps become backward
  197. if "FORWARD" in inst.opname:
  198. flip_jump_direction(inst)
  199. elif inst.arg > 0:
  200. # backward jumps become forward
  201. if sys.version_info >= (3, 11) and "BACKWARD" in inst.opname:
  202. flip_jump_direction(inst)
  203. if sys.version_info >= (3, 10):
  204. # see bytecode size comment in the absolute jump case above
  205. inst.arg //= 2
  206. inst.argval = target.offset
  207. inst.argrepr = f"to {target.offset}"
  208. def strip_extended_args(instructions: List[Instruction]):
  209. instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
  210. def remove_load_call_method(instructions: List[Instruction]):
  211. """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
  212. rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
  213. for inst in instructions:
  214. if inst.opname in rewrites:
  215. inst.opname = rewrites[inst.opname]
  216. inst.opcode = dis.opmap[inst.opname]
  217. return instructions
  218. def explicit_super(code: types.CodeType, instructions: List[Instruction]):
  219. """convert super() with no args into explict arg form"""
  220. cell_and_free = (code.co_cellvars or tuple()) + (code.co_freevars or tuple())
  221. output = []
  222. for idx, inst in enumerate(instructions):
  223. output.append(inst)
  224. if inst.opname == "LOAD_GLOBAL" and inst.argval == "super":
  225. nexti = instructions[idx + 1]
  226. if nexti.opname == "CALL_FUNCTION" and nexti.arg == 0:
  227. assert "__class__" in cell_and_free
  228. output.append(
  229. create_instruction(
  230. "LOAD_DEREF", cell_and_free.index("__class__"), "__class__"
  231. )
  232. )
  233. first_var = code.co_varnames[0]
  234. if first_var in cell_and_free:
  235. output.append(
  236. create_instruction(
  237. "LOAD_DEREF", cell_and_free.index(first_var), first_var
  238. )
  239. )
  240. else:
  241. output.append(create_instruction("LOAD_FAST", 0, first_var))
  242. nexti.arg = 2
  243. nexti.argval = 2
  244. instructions[:] = output
  245. def fix_extended_args(instructions: List[Instruction]):
  246. """Fill in correct argvals for EXTENDED_ARG ops"""
  247. output = []
  248. def maybe_pop_n(n):
  249. for _ in range(n):
  250. if output and output[-1].opcode == dis.EXTENDED_ARG:
  251. output.pop()
  252. for i, inst in enumerate(instructions):
  253. if inst.opcode == dis.EXTENDED_ARG:
  254. # Leave this instruction alone for now so we never shrink code
  255. inst.arg = 0
  256. elif inst.arg and inst.arg > 0xFFFFFF:
  257. maybe_pop_n(3)
  258. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 24))
  259. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 16))
  260. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
  261. elif inst.arg and inst.arg > 0xFFFF:
  262. maybe_pop_n(2)
  263. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 16))
  264. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
  265. elif inst.arg and inst.arg > 0xFF:
  266. maybe_pop_n(1)
  267. output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
  268. output.append(inst)
  269. added = len(output) - len(instructions)
  270. assert added >= 0
  271. instructions[:] = output
  272. return added
  273. # from https://github.com/python/cpython/blob/v3.11.1/Include/internal/pycore_opcode.h#L41
  274. # TODO use the actual object instead, can interface from eval_frame.c
  275. _PYOPCODE_CACHES = {
  276. "BINARY_SUBSCR": 4,
  277. "STORE_SUBSCR": 1,
  278. "UNPACK_SEQUENCE": 1,
  279. "STORE_ATTR": 4,
  280. "LOAD_ATTR": 4,
  281. "COMPARE_OP": 2,
  282. "LOAD_GLOBAL": 5,
  283. "BINARY_OP": 1,
  284. "LOAD_METHOD": 10,
  285. "PRECALL": 1,
  286. "CALL": 4,
  287. }
  288. def instruction_size(inst):
  289. if sys.version_info >= (3, 11):
  290. return 2 * (_PYOPCODE_CACHES.get(dis.opname[inst.opcode], 0) + 1)
  291. return 2
  292. def check_offsets(instructions):
  293. offset = 0
  294. for inst in instructions:
  295. assert inst.offset == offset
  296. offset += instruction_size(inst)
  297. def update_offsets(instructions):
  298. offset = 0
  299. for inst in instructions:
  300. inst.offset = offset
  301. offset += instruction_size(inst)
  302. def debug_bytes(*args):
  303. index = range(max(map(len, args)))
  304. result = []
  305. for arg in (
  306. [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
  307. ):
  308. result.append(" ".join(f"{x:03}" for x in arg))
  309. return "bytes mismatch\n" + "\n".join(result)
  310. def debug_checks(code):
  311. """Make sure our assembler produces same bytes as we start with"""
  312. dode = transform_code_object(code, lambda x, y: None, safe=True)
  313. assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code)
  314. assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab)
  315. HAS_LOCAL = set(dis.haslocal)
  316. HAS_NAME = set(dis.hasname)
  317. def fix_vars(instructions: List[Instruction], code_options):
  318. varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
  319. names = {name: idx for idx, name in enumerate(code_options["co_names"])}
  320. for i in range(len(instructions)):
  321. if instructions[i].opcode in HAS_LOCAL:
  322. instructions[i].arg = varnames[instructions[i].argval]
  323. elif instructions[i].opcode in HAS_NAME:
  324. instructions[i].arg = names[instructions[i].argval]
  325. def transform_code_object(code, transformations, safe=False):
  326. # Python 3.11 changes to code keys are not fully documented.
  327. # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
  328. # for new format.
  329. keys = ["co_argcount"]
  330. keys.append("co_posonlyargcount")
  331. keys.extend(
  332. [
  333. "co_kwonlyargcount",
  334. "co_nlocals",
  335. "co_stacksize",
  336. "co_flags",
  337. "co_code",
  338. "co_consts",
  339. "co_names",
  340. "co_varnames",
  341. "co_filename",
  342. "co_name",
  343. ]
  344. )
  345. if sys.version_info >= (3, 11):
  346. keys.append("co_qualname")
  347. keys.append("co_firstlineno")
  348. if sys.version_info >= (3, 10):
  349. keys.append("co_linetable")
  350. else:
  351. keys.append("co_lnotab")
  352. if sys.version_info >= (3, 11):
  353. # not documented, but introduced in https://github.com/python/cpython/issues/84403
  354. keys.append("co_exceptiontable")
  355. keys.extend(
  356. [
  357. "co_freevars",
  358. "co_cellvars",
  359. ]
  360. )
  361. code_options = {k: getattr(code, k) for k in keys}
  362. assert len(code_options["co_varnames"]) == code_options["co_nlocals"]
  363. instructions = cleaned_instructions(code, safe)
  364. propagate_line_nums(instructions)
  365. transformations(instructions, code_options)
  366. return clean_and_assemble_instructions(instructions, keys, code_options)[1]
  367. def clean_and_assemble_instructions(
  368. instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
  369. ) -> Tuple[List[Instruction], types.CodeType]:
  370. fix_vars(instructions, code_options)
  371. dirty = True
  372. while dirty:
  373. update_offsets(instructions)
  374. devirtualize_jumps(instructions)
  375. # this pass might change offsets, if so we need to try again
  376. dirty = fix_extended_args(instructions)
  377. remove_extra_line_nums(instructions)
  378. bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])
  379. if sys.version_info < (3, 10):
  380. code_options["co_lnotab"] = lnotab
  381. else:
  382. code_options["co_linetable"] = lnotab
  383. code_options["co_code"] = bytecode
  384. code_options["co_nlocals"] = len(code_options["co_varnames"])
  385. code_options["co_stacksize"] = stacksize_analysis(instructions)
  386. assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - {
  387. "co_posonlyargcount"
  388. }
  389. if sys.version_info >= (3, 11):
  390. # generated code doesn't contain exceptions, so leave exception table empty
  391. code_options["co_exceptiontable"] = b""
  392. return instructions, types.CodeType(*[code_options[k] for k in keys])
  393. def cleaned_instructions(code, safe=False):
  394. instructions = list(map(convert_instruction, dis.get_instructions(code)))
  395. check_offsets(instructions)
  396. virtualize_jumps(instructions)
  397. strip_extended_args(instructions)
  398. if not safe:
  399. remove_load_call_method(instructions)
  400. explicit_super(code, instructions)
  401. return instructions
  402. _unique_id_counter = itertools.count()
  403. def unique_id(name):
  404. return f"{name}_{next(_unique_id_counter)}"
  405. def is_generator(code: types.CodeType):
  406. co_generator = 0x20
  407. return (code.co_flags & co_generator) > 0