check_kernel_launches.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import os
  2. import re
  3. import sys
  4. from typing import List
  # Names exported by a star-import of this module.
  __all__ = ["check_code_for_cuda_kernel_launches", "check_cuda_kernel_launches"]
  # Files the checker skips. An entry excludes every filename that ends with
  # it (the comparison is a suffix match done with `str.endswith`).
  # Launching a kernel without a check is like driving without a seatbelt, so
  # treat an entry here as a quick workaround for a problem in the checker:
  # fix the checker, then remove the entry again.
  exclude_files: List[str] = list()
  # Kernel launches cannot be detected with full accuracy without a C++ AST,
  # so a launch is modelled as text of the form "<<<parameters>>>(arguments);"
  # and `C10_CUDA_KERNEL_LAUNCH_CHECK` is required to be the next statement.
  #
  # The next statement is taken to end at the following `}` or `;`. Reaching
  # a `}` first means a clause ended without a check (bad); reaching a `;`
  # means the launch check is expected immediately before it.
  #
  # The launch arguments may contain lambdas, so the closing paren of the
  # launch has to be located by matching parens. A pure-regex solution would
  # need recursive patterns, which the standard library `re` module does not
  # offer. To avoid an extra dependency the work is split in three steps:
  # a prefix regex locates the start of a launch, a paren-matching routine
  # locates its end, and a second regex decides whether a check follows.

  # Matches from the start of a line through the `(` that opens the argument
  # list of a potential kernel launch.
  kernel_launch_start = re.compile(r"^.*<<<[^>]+>>>\s*\(", re.MULTILINE)
  # Apply this pattern at the character right after the closing paren of a
  # kernel launch. Despite the name, it produces a match when the launch
  # check is NOT the next statement; no match means the check was found.
  has_check = re.compile(r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", re.MULTILINE)
  def find_matching_paren(s: str, startpos: int) -> int:
      """Return the index one past the `)` that closes the `(` at `startpos`.

      Given a string "prefix (unknown number of characters) suffix" and the
      position of the first `(`, scans forward while tracking paren nesting.

      Args:
          s - The string to scan
          startpos - Index in `s` at which scanning starts (the opening `(`)
      Returns:
          The index of the character 1 past the matching `)`
      Raises:
          IndexError - If the end of `s` is reached before the parens balance
      """
      depth = 0
      # Walk the string in place by index; slicing `s[startpos:]` would copy
      # the whole remainder of the string on every call.
      for i in range(startpos, len(s)):
          c = s[i]
          if c == "(":
              depth += 1
          elif c == ")":
              depth -= 1
              if depth == 0:
                  return i + 1
      raise IndexError("Closing parens not found!")
  def should_exclude_file(filename) -> bool:
      """Report whether `filename` ends with any suffix in `exclude_files`."""
      return any(filename.endswith(suffix) for suffix in exclude_files)
  def check_code_for_cuda_kernel_launches(code, filename=None):
      """Checks code for CUDA kernel launches without cuda error checks.

      Every launch that is not followed by the launch check is reported on
      stderr together with the (line-numbered) text of the launch.

      Args:
          code - The code to check
          filename - Filename of file containing the code. Used only for display
                     purposes, so you can put anything here.
      Returns:
          The number of unsafe kernel launches in the code
      Raises:
          IndexError - Propagated from `find_matching_paren` when the parens of
                       a detected launch never close
      """
      if filename is None:
          filename = "##Python Function Call##"

      # Prefix every line with its line number so the reported context points
      # at the problem area. Numbering starts at 1 to agree with editors and
      # compilers (a bare `enumerate` would label the first line as 0).
      numbered_lines = [
          f"{lineno}: {linecode}"
          for lineno, linecode in enumerate(code.split("\n"), start=1)
      ]
      code = "\n".join(numbered_lines)

      num_launches_without_checks = 0
      for m in kernel_launch_start.finditer(code):
          # The launch pattern ends on the `(` that opens the argument list,
          # so `m.end() - 1` is the paren to match.
          end_paren = find_matching_paren(code, m.end() - 1)
          # `has_check` matches when the check is NOT the next statement.
          if has_check.match(code, end_paren):
              num_launches_without_checks += 1
              context = code[m.start():end_paren + 1]
              print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
      return num_launches_without_checks
  def check_file(filename):
      """Checks a file for CUDA kernel launches without cuda error checks

      Only files whose name ends in `.cu` or `.cuh` are read; any other file,
      and any file matched by `should_exclude_file`, reports 0 without being
      opened.

      Args:
          filename - File to check
      Returns:
          The number of unsafe kernel launches in the file
      """
      if not filename.endswith((".cu", ".cuh")):
          return 0
      if should_exclude_file(filename):
          return 0
      # Decode explicitly as UTF-8 instead of the locale default so results do
      # not vary by platform, and replace undecodable bytes so that one stray
      # byte (e.g. in a comment) cannot abort the scan.
      with open(filename, "r", encoding="utf-8", errors="replace") as fo:
          contents = fo.read()
      return check_code_for_cuda_kernel_launches(contents, filename)
  def check_cuda_kernel_launches():
      """Checks all pytorch code for CUDA kernel launches without cuda error checks

      Walks the directory tree rooted two levels above the directory that
      contains this file, runs `check_file` on every file found, and prints a
      summary to stderr when any unchecked launch is detected.

      Returns:
          The number of unsafe kernel launches in the codebase
      """
      torch_dir = os.path.dirname(os.path.realpath(__file__))
      torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
      torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2

      # `$BASE/build` and `$BASE/torch/include` are generated
      # so we don't want to flag their contents.
      # Each path component is joined separately so the result uses the
      # platform's separator and compares equal to the `root` values that
      # `os.walk` yields (a literal "torch/include" never matches on Windows).
      generated_dirs = {
          os.path.join(torch_dir, "build"),
          os.path.join(torch_dir, "torch", "include"),
      }

      kernels_without_checks = 0
      files_without_checks = []
      for root, dirnames, filenames in os.walk(torch_dir):
          if root in generated_dirs:
              # Curtail the search: clearing dirnames in place stops the walk
              # from descending further, and `continue` skips this
              # directory's own files. See `help(os.walk)`.
              dirnames[:] = []
              continue

          for x in filenames:
              filename = os.path.join(root, x)
              file_result = check_file(filename)
              if file_result > 0:
                  kernels_without_checks += file_result
                  files_without_checks.append(filename)

      if kernels_without_checks > 0:
          count_str = f"Found {kernels_without_checks} instances in " \
                      f"{len(files_without_checks)} files where kernel " \
                      "launches didn't have checks."
          # The summary is printed both before and after the file list.
          print(count_str, file=sys.stderr)
          print("Files without checks:", file=sys.stderr)
          for x in files_without_checks:
              print(f"\t{x}", file=sys.stderr)
          print(count_str, file=sys.stderr)

      return kernels_without_checks
  if __name__ == "__main__":
      # Exit status is 0 when no unchecked launch was found and 1 otherwise.
      sys.exit(1 if check_cuda_kernel_launches() != 0 else 0)