#!/usr/bin/env python3
""" The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
# 2017-2018 Advanced Micro Devices, Inc. and
# Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os

from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS

from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable

HipifyResult = Dict[str, Optional[str]]
HipifyFinalResult = Dict[str, HipifyResult]
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}

__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file',
           'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']


class InputError(Exception):
    # Exception raised for errors in the input.

    def __init__(self, message):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return "{}: {}".format("Input error", self.message)


def openf(filename, mode):
    return open(filename, mode, errors='ignore')


# Color coding for printing
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


# To the programmer, the outputs of hipify are most likely intermediates.
# This class allows users of hipify to request a cleanup by running hipify
# and the compilation inside a `with` block that instantiates this context
# manager with keep_intermediates=False.
# The main use case is cpp_extensions, specifically the load method.
# It is usually a good idea to keep intermediates (in case of errors, or to
# avoid recompiling unchanged files), but in cases where you don't want to
# keep them (e.g. in the CI), this can be used to remove the generated files.
class GeneratedFileCleaner:
    """Context Manager to clean up generated files"""

    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        self.files_to_clean = set()
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        if not os.path.exists(fn):
            self.files_to_clean.add(os.path.abspath(fn))
        return open(fn, *args, **kwargs)

    def makedirs(self, dn, exist_ok=False):
        parent, n = os.path.split(dn)
        if not n:
            parent, n = os.path.split(parent)
        if parent and n and not os.path.exists(parent):
            self.makedirs(parent, exist_ok=True)
        if not os.path.isdir(dn) or not exist_ok:
            os.mkdir(dn)
            self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        if not self.keep_intermediates:
            for f in self.files_to_clean:
                os.unlink(f)
            for d in self.dirs_to_clean[::-1]:
                os.rmdir(d)
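
# Illustrative usage of GeneratedFileCleaner (a minimal sketch; the paths and
# the `hipified_source` variable below are hypothetical):
#
#     with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
#         clean_ctx.makedirs("build/hip", exist_ok=True)
#         with clean_ctx.open("build/hip/kernel_hip.cpp", "w") as f:
#             f.write(hipified_source)
#     # On exit, files and directories created through clean_ctx are removed,
#     # because keep_intermediates is False.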
def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Helper method to see if filename ends with certain extension"""
    return any(filename.endswith(e) for e in extensions)


def _fnmatch(filepath, patterns):
    return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)


def matched_files_iter(
        root_path: str,
        includes: Iterable = (),
        ignores: Iterable = (),
        extensions: Iterable = (),
        out_of_place_only: bool = False,
        is_pytorch_extension: bool = False) -> Iterator[str]:
    exact_matches = set(includes)

    # This is a very rough heuristic; really, we want to avoid scanning
    # any file which is not checked into source control, but this script
    # needs to work even if you're in a Git or Hg checkout, so it is easier
    # to just block the biggest time sinks that won't matter in the end.
    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # Prune directories we never want to scan; this runs only once,
            # at the top level of the tree.
            if ".git" in dirs:
                dirs.remove(".git")
            if "build" in dirs:
                dirs.remove("build")
            if "third_party" in dirs:
                dirs.remove("third_party")
                dirs.append("third_party/nvfuser")
        for filename in filenames:
            filepath = os.path.join(abs_dirpath, filename)
            rel_filepath = os.path.join(rel_dirpath, filename)
            # We respect extensions, UNLESS you wrote the entire
            # filename verbatim, in which case we always accept it
            if (
                _fnmatch(filepath, includes)
                and (not _fnmatch(filepath, ignores))
                and (match_extensions(filepath, extensions) or filepath in exact_matches)
            ):
                if not is_pytorch_extension:  # for pytorch extensions, consider all files
                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
                        continue
                if out_of_place_only and not is_out_of_place(rel_filepath):
                    continue
                yield filepath


def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    # Show what happened
    if show_progress and "ignored" not in str(result["status"]):
        print(
            fin_path, "->",
            result["hipified_path"], result["status"], flush=True)

    HIPIFY_FINAL_RESULT[fin_path] = result


def compute_stats(stats):
    unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}

    # Print the number of unsupported calls
    print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))

    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))


def add_dim3(kernel_string, cuda_kernel):
    '''adds dim3() to the second and third arguments in the kernel launch'''
    count = 0
    closure = 0
    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
    arg_locs: List[Dict[str, int]] = [{} for _ in range(2)]
    arg_locs[count]['start'] = 0
    for ind, c in enumerate(kernel_string):
        if count > 1:
            break
        if c == "(":
            closure += 1
        elif c == ")":
            closure -= 1
        if (c == "," or ind == len(kernel_string) - 1) and closure == 0:
            arg_locs[count]['end'] = ind + (c != ",")
            count += 1
            if count < 2:
                arg_locs[count]['start'] = ind + 1

    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]

    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

    first_arg_dim3 = "dim3({})".format(first_arg_clean)
    second_arg_dim3 = "dim3({})".format(second_arg_clean)

    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
    return cuda_kernel
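
# Illustrative example of add_dim3 (the kernel and argument names below are
# made up for the example):
#
#     add_dim3("<<<grid, block, 0, stream>>>",
#              "my_kernel<<<grid, block, 0, stream>>>(")
#     # -> "my_kernel<<<dim3(grid), dim3(block), 0, stream>>>("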
RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')


def processKernelLaunches(string, stats):
    """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
    # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)

    def grab_method_and_template(in_kernel):
        # The positions for relevant kernel components.
        pos = {
            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
            "kernel_name": {"start": -1, "end": -1},
            "template": {"start": -1, "end": -1}
        }

        # Count for balancing template
        count = {"<>": 0}

        # Status for whether we are parsing a certain item.
        START = 0
        AT_TEMPLATE = 1
        AFTER_TEMPLATE = 2
        AT_KERNEL_NAME = 3

        status = START

        # Parse the string character by character
        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
            char = string[i]

            # Handle Templating Arguments
            if status == START or status == AT_TEMPLATE:
                if char == ">":
                    if status == START:
                        status = AT_TEMPLATE
                        pos["template"]["end"] = i
                    count["<>"] += 1

                if char == "<":
                    count["<>"] -= 1
                    if count["<>"] == 0 and (status == AT_TEMPLATE):
                        pos["template"]["start"] = i
                        status = AFTER_TEMPLATE

            # Handle Kernel Name
            if status != AT_TEMPLATE:
                if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
                    if status != AT_KERNEL_NAME:
                        status = AT_KERNEL_NAME
                        pos["kernel_name"]["end"] = i

                    # Case: Kernel name starts the string.
                    if i == 0:
                        pos["kernel_name"]["start"] = 0

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

                else:
                    # Potential ending point if we're already traversing a kernel's name.
                    if status == AT_KERNEL_NAME:
                        pos["kernel_name"]["start"] = i

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

    def find_kernel_bounds(string):
        """Finds the starting and ending points for all kernel launches in the string."""
        kernel_end = 0
        kernel_positions = []

        # Continue until we cannot find any more kernels.
        while string.find("<<<", kernel_end) != -1:
            # Get kernel starting position (starting from the previous ending point)
            kernel_start = string.find("<<<", kernel_end)

            # Get kernel ending position (adjust end point past the >>>)
            kernel_end = string.find(">>>", kernel_start) + 3
            if kernel_end <= 0:
                raise InputError("no kernel end found")

            # Add to list of traversed kernels
            kernel_positions.append({"start": kernel_start, "end": kernel_end,
                                     "group": string[kernel_start: kernel_end]})
        return kernel_positions

    # Mask out comments and string literals so that find_kernel_bounds does not
    # wrongly capture kernels inside them.
    # This function replaces their contents with "x" to keep positions.
    def mask_comments(string):
        in_comment = ''
        prev_c = ''
        new_string = ''
        for c in string:
            if in_comment == '':
                # Outside comments
                if c == '/' and prev_c == '/':
                    in_comment = '//'
                elif c == '*' and prev_c == '/':
                    in_comment = '/*'
                elif c == '"' and prev_c != '\\' and prev_c != "'":
                    in_comment = '"'
            elif in_comment == '//':
                # In // xxx
                if c == '\r' or c == '\n':
                    in_comment = ''
            elif in_comment == '/*':
                # In /* xxx */
                if c == '/' and prev_c == '*':
                    in_comment = ''
            elif in_comment == '"':
                # In ""
                if c == '"' and prev_c != '\\':
                    in_comment = ''
            prev_c = c
            if in_comment == '':
                new_string += c
            else:
                new_string += 'x'
        return new_string

    # Grab positional ranges of all kernel launches
    get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
    output_string = string

    # Replace each CUDA kernel with a HIP kernel.
    for kernel in get_kernel_positions:
        # Get kernel components
        params = grab_method_and_template(kernel)

        # Find parenthesis after kernel launch
        parenthesis = string.find("(", kernel["end"])

        # Extract cuda kernel
        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
        kernel_string = string[kernel['start']:kernel['end']]
        end_param_index = 0 if params[1]['end'] == -1 else 1
        kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)

        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))

        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
            ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")

        # Replace cuda kernel with hip kernel
        output_string = output_string.replace(cuda_kernel, hip_kernel)

        # Update the statistics
        stats["kernel_launches"].append(hip_kernel)

    return output_string
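
# Illustrative end-to-end example of processKernelLaunches (the kernel and
# argument names are made up):
#
#     stats = {"unsupported_calls": [], "kernel_launches": []}
#     processKernelLaunches("my_kernel<<<grid, block, 0, stream>>>(arg);", stats)
#     # -> "hipLaunchKernelGGL((my_kernel), dim3(grid), dim3(block), 0, stream, arg);"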
def find_closure_group(input_string, start, group):
    """Generalization for finding a balancing closure group

    if group = ["(", ")"], then finds the first balanced parentheses.
    if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    """
    inside_parenthesis = False
    parens = 0
    pos = start
    p_start, p_end = -1, -1

    while pos < len(input_string):
        if input_string[pos] == group[0]:
            if inside_parenthesis is False:
                inside_parenthesis = True
                parens = 1
                p_start = pos
            else:
                parens += 1
        elif input_string[pos] == group[1] and inside_parenthesis:
            parens -= 1

            if parens == 0:
                p_end = pos
                return p_start, p_end

        pos += 1
    return None, None


def find_bracket_group(input_string, start):
    """Finds the first balanced brace group ("{" ... "}")."""
    return find_closure_group(input_string, start, group=["{", "}"])


def find_parentheses_group(input_string, start):
    """Finds the first balanced parentheses group ("(" ... ")")."""
    return find_closure_group(input_string, start, group=["(", ")"])


RE_ASSERT = re.compile(r"\bassert[ ]*\(")


def replace_math_functions(input_string):
    """FIXME: Temporarily replace std:: invocations of math functions
    with non-std:: versions to prevent linker errors. NOTE: This
    can lead to correctness issues when running tests, since the
    correct version of the math function (exp/expf) might not get
    called. The plan is to remove this function once HIP supports
    std:: math function calls inside device code.
    """
    output_string = input_string
    for func in MATH_TRANSPILATIONS:
        output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func]))

    return output_string
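
# Illustrative example of replace_math_functions, assuming MATH_TRANSPILATIONS
# (defined in cuda_to_hip_mappings) contains an entry such as 'std::exp' -> '::exp':
#
#     replace_math_functions("y = std::exp(x);")  # -> "y = ::exp(x);"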
RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")


def hip_header_magic(input_string):
    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    """

    # Copy the input.
    output_string = input_string

    # Check if one of the following headers is already included.
    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
    if any(re.search(r'#include ("{0}"|<{0}>)'.format(ext), output_string) for ext in headers):
        return output_string

    # Rough logic to detect if we're inside device code
    hasDeviceLogic: int
    hasDeviceLogic = "hipLaunchKernelGGL" in output_string
    hasDeviceLogic += "__global__" in output_string
    hasDeviceLogic += "__shared__" in output_string
    hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None

    # If device logic found, provide the necessary header.
    if hasDeviceLogic:
        output_string = '#include "hip/hip_runtime.h"\n' + input_string

    return output_string
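
# Illustrative behaviour of hip_header_magic: a source string that uses device
# constructs but includes no HIP runtime header gets the include prepended.
#
#     hip_header_magic("__global__ void k() {}")
#     # -> '#include "hip/hip_runtime.h"\n__global__ void k() {}'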
RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")


def replace_extern_shared(input_string):
    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
    https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];"
            => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];"
            => "HIP_DYNAMIC_SHARED( unsigned char, smem)"
    """
    output_string = input_string
    output_string = RE_EXTERN_SHARED.sub(
        lambda inp: "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
            inp.group(1) or "", inp.group(2), inp.group(3)), output_string)

    return output_string


def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
    """
    Returns the new name of the hipified file
    """
    # At the moment, some PyTorch source files are HIPified in place. The predicate
    # is_out_of_place tells us if this is the case or not.
    assert not os.path.isabs(rel_filepath)
    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
        return rel_filepath

    dirpath, filename = os.path.split(rel_filepath)
    root, ext = os.path.splitext(filename)

    # Here's the plan:
    #
    # In general, we need to disambiguate the HIPified filename so that
    # it gets a different name from the original filename, so
    # that we don't overwrite the original file
    #
    # There's a lot of different naming conventions across PyTorch
    # and Caffe2, but the general recipe is to convert occurrences
    # of cuda/gpu to hip, and add hip if there are no occurrences
    # of cuda/gpu anywhere.
    #
    # Concretely, we do the following:
    #
    #   - If there is a directory component named "cuda", replace
    #     it with "hip", AND
    #
    #   - If the file name contains "CUDA", replace it with "HIP", AND
    #
    #   - ALWAYS replace '.cu' with '.hip', because those files
    #     contain CUDA kernels that need to be hipified and processed with
    #     the hip compiler
    #
    #   - If we are not hipifying a PyTorch extension, and the parent
    #     directory name did not change as a result of the above
    #     transformations, insert "hip" in the file path
    #     as the direct parent folder of the file
    #
    #   - If we are hipifying a PyTorch extension, and the parent directory
    #     name as well as the filename (incl. extension) did not change as
    #     a result of the above transformations, insert "_hip" in the filename
    #
    # This isn't set in stone; we might adjust this to support other
    # naming conventions.

    if ext == '.cu':
        ext = '.hip'

    orig_filename = filename
    orig_dirpath = dirpath

    dirpath = dirpath.replace('cuda', 'hip')
    dirpath = dirpath.replace('CUDA', 'HIP')
    dirpath = dirpath.replace('THC', 'THH')

    root = root.replace('cuda', 'hip')
    root = root.replace('CUDA', 'HIP')
    # Special case to handle caffe2/core/THCCachingAllocator
    if dirpath != "caffe2/core":
        root = root.replace('THC', 'THH')

    if not is_pytorch_extension and dirpath == orig_dirpath:
        dirpath = os.path.join(dirpath, 'hip')

    if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
        root = root + "_hip"

    return os.path.join(dirpath, root + ext)
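
# Illustrative examples of get_hip_file_path (the paths are for illustration only):
#
#     get_hip_file_path("aten/src/ATen/native/cuda/Activation.cu")
#     # -> "aten/src/ATen/native/hip/Activation.hip"
#     #    ("cuda" directory becomes "hip", ".cu" becomes ".hip")
#
#     get_hip_file_path("caffe2/operators/my_op.cu")
#     # -> "caffe2/operators/hip/my_op.hip"
#     #    (the directory name did not change, so a "hip" parent dir is inserted)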
def is_out_of_place(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("torch/"):
        return False
    if rel_filepath.startswith("third_party/nvfuser/"):
        return False
    if rel_filepath.startswith("tools/autograd/templates/"):
        return False
    return True


# Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("aten/"):
        if rel_filepath.startswith("aten/src/ATen/core/"):
            return False
        return True
    if rel_filepath.startswith("torch/"):
        return True
    if rel_filepath.startswith("third_party/nvfuser/"):
        return True
    if rel_filepath.startswith("tools/autograd/templates/"):
        return True
    return False


def is_cusparse_file(rel_filepath):
    if is_pytorch_file(rel_filepath):
        return "sparse" in rel_filepath.lower()
    return False


def is_caffe2_gpu_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("c10/cuda"):
        return True
    filename = os.path.basename(rel_filepath)
    _, ext = os.path.splitext(filename)
    return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
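
# Illustrative results for is_caffe2_gpu_file (the file names are made up):
#
#     is_caffe2_gpu_file("c10/cuda/CUDAStream.cpp")            # True  (lives under c10/cuda)
#     is_caffe2_gpu_file("caffe2/operators/relu_op.cu")        # True  (.cu extension)
#     is_caffe2_gpu_file("caffe2/operators/conv_op_cudnn.cc")  # False ("cudnn" files are excluded)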
# Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
class Trie():
    """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union."""

    def __init__(self):
        self.data = {}

    def add(self, word):
        ref = self.data
        for char in word:
            ref[char] = char in ref and ref[char] or {}
            ref = ref[char]
        ref[''] = 1

    def dump(self):
        return self.data

    def quote(self, char):
        return re.escape(char)

    def _pattern(self, pData):
        data = pData
        if "" in data and len(data.keys()) == 1:
            return None

        alt = []
        cc = []
        q = 0
        for char in sorted(data.keys()):
            if isinstance(data[char], dict):
                try:
                    recurse = self._pattern(data[char])
                    alt.append(self.quote(char) + recurse)
                except Exception:
                    cc.append(self.quote(char))
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = "(?:%s)?" % result
        return result

    def pattern(self):
        return self._pattern(self.dump())
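
# Illustrative Trie usage (the two identifiers are just examples):
#
#     t = Trie()
#     t.add("cudaMalloc")
#     t.add("cudaMemcpy")
#     t.pattern()  # -> 'cudaM(?:alloc|emcpy)'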
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP: Dict[str, object] = {}

# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
PYTORCH_SPARSE_MAP = {}

for mapping in CUDA_TO_HIP_MAPPINGS:
    assert isinstance(mapping, Mapping)
    for src, value in mapping.items():
        dst = value[0]
        meta_data = value[1:]
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            # if src is already in PYTORCH_MAP and dst belongs to API_SPARSE
            # do not overwrite PYTORCH_MAP, store dst separately
            if constants.API_SPARSE in meta_data and PYTORCH_MAP.get(src, ""):
                PYTORCH_SPARSE_MAP[src] = dst
            else:
                PYTORCH_MAP[src] = dst
        if constants.API_PYTORCH not in meta_data:
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst

RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern()))

RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh

"""
Returns a dict with the following keys:
    "hipified_path" : absolute path of hipified source file
    "status"        : "ok"      if hipified file was written out
                      "skipped" if an identical hipified file already existed or hipified file couldn't be written out
                      "ignored" if the source file was a hipified file itself or not meant to be hipified
"""
def preprocessor(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> HipifyResult:
    """ Executes the CUDA -> HIP conversion on the specified file. """
    if filepath not in all_files:
        return {"hipified_path": None, "status": "[ignored, not to be hipified]"}

    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    rel_filepath = os.path.relpath(filepath, output_directory)

    with open(fin_path, 'r', encoding='utf-8') as fin:
        if fin.readline() == HIPIFY_C_BREADCRUMB:
            return {"hipified_path": None, "status": "[ignored, input is hipified output]"}
        fin.seek(0)
        output_source = fin.read()

    orig_output_source = output_source

    # get_hip_file_path needs a relative path to work correctly
    fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
    if not os.path.exists(os.path.dirname(fout_path)):
        clean_ctx.makedirs(os.path.dirname(fout_path))

    # unsupported_calls statistics reporting is broken atm
    def pt_repl(m):
        return PYTORCH_MAP[m.group(0)]

    def pt_sparse_repl(m):
        # checks SPARSE map first, and if a miss occurs, falls back to pytorch mappings
        return PYTORCH_SPARSE_MAP.get(m.group(0), pt_repl(m))

    if is_pytorch_extension:
        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
    else:
        if is_cusparse_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
        elif is_pytorch_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
        else:
            def c2_repl(m):
                return CAFFE2_MAP[m.group(0)]
            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)

    # Header rewrites
    def mk_repl(templ, include_current_dir=True):
        def repl(m):
            f = m.group(1)
            dirpath, filename = os.path.split(f)
            if (
                f.startswith("ATen/cuda")
                or f.startswith("ATen/native/cuda")
                or f.startswith("ATen/native/nested/cuda")
                or f.startswith("ATen/native/quantized/cuda")
                or f.startswith("ATen/native/sparse/cuda")
                or f.startswith("ATen/native/transformers/cuda")
                or f.startswith("THC/")
                or (f.startswith("THC") and not f.startswith("THCP"))
            ):
                return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
            # if filename is one of the files being hipified for this extension
            if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
                header_dir = None
                header_filepath = None
                # If include_current_dir True, look first in same dir as the including source file
                if include_current_dir:
                    header_dir_to_check = os.path.dirname(fin_path)
                    header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                    if os.path.exists(header_path_to_check):
                        header_dir = header_dir_to_check
                        header_filepath = header_path_to_check
                # If not found, look in include dirs one by one and first match wins
                if header_filepath is None:
                    for header_include_dir in header_include_dirs:
                        header_dir_to_check = os.path.join(output_directory, header_include_dir)
                        header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                        if os.path.exists(header_path_to_check):
                            header_dir = header_dir_to_check
                            header_filepath = header_path_to_check
                # If header file not found, keep as is
                if header_filepath is None:
                    return m.group(0)
                # Hipify header file first if needed
                if header_filepath not in HIPIFY_FINAL_RESULT:
                    preprocess_file_and_save_result(output_directory,
                                                    header_filepath,
                                                    all_files, header_include_dirs, stats, hip_clang_launch,
                                                    is_pytorch_extension, clean_ctx, show_progress)
                hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
                return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
                                                    else header_filepath, header_dir))

            return m.group(0)
        return repl

    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)

    # CMakeLists.txt rewrites
    if filepath.endswith('CMakeLists.txt'):
        output_source = output_source.replace('CUDA', 'HIP')
        output_source = output_source.replace('THC', 'THH')
        output_source = RE_CU_SUFFIX.sub('.hip', output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath:
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    # NOTE: No longer needed after transition from hcc to hipclang.
    # output_source = replace_extern_shared(output_source)

    # Don't write out identical hipified files for extensions if dirpath has not changed
    if (
        is_pytorch_extension
        and orig_output_source == output_source
        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
    ):
        return {"hipified_path": fin_path, "status": "[skipped, no changes]"}

    # Add hipify breadcrumb for C-style files to avoid re-hipification
    if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
        output_source = HIPIFY_C_BREADCRUMB + output_source

    do_write = True
    if os.path.exists(fout_path):
        with open(fout_path, 'r', encoding='utf-8') as fout_old:
            do_write = fout_old.read() != output_source
    if do_write:
        try:
            with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
                fout.write(output_source)
            return {"hipified_path": fout_path, "status": "[ok]"}
        except PermissionError as e:
            print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}",
                  file=sys.stderr)
            return {"hipified_path": fin_path, "status": "[skipped, no permissions]"}
    else:
        return {"hipified_path": fout_path, "status": "[skipped, already hipified]"}
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if strict:
            contents = re.sub(r'\b({0})\b'.format(re.escape(search_string)), lambda x: replace_string, contents)
        else:
            contents = contents.replace(search_string, replace_string)
        f.seek(0)
        f.write(contents)
        f.truncate()


def file_add_header(filepath, header):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if header[0] != "<" and header[-1] != ">":
            header = '"{0}"'.format(header)
        contents = ('#include {0} \n'.format(header)) + contents
        f.seek(0)
        f.write(contents)
        f.truncate()


def fix_static_global_kernels(in_txt):
    """Static global kernels in HIP result in a compilation error."""
    in_txt = in_txt.replace(" __global__ static", "__global__")
    return in_txt
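
# Illustrative example of fix_static_global_kernels:
#
#     fix_static_global_kernels(" __global__ static void k() {}")
#     # -> "__global__ void k() {}"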
RE_INCLUDE = re.compile(r"#include .*\n")


def extract_arguments(start, string):
    """ Return the list of arguments in the upcoming function parameter closure.
    Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
              {'start': 8, 'end': 16},
              {'start': 17, 'end': 19},
              {'start': 20, 'end': 53}]'
    """
    arguments = []
    closures = {
        "<": 0,
        "(": 0
    }
    current_position = start
    argument_start_pos = current_position + 1

    # Search for final parenthesis
    while current_position < len(string):
        if string[current_position] == "(":
            closures["("] += 1
        elif string[current_position] == ")":
            closures["("] -= 1
        elif string[current_position] == "<":
            closures["<"] += 1
        elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
            closures["<"] -= 1

        # Finished all arguments
        if closures["("] == 0 and closures["<"] == 0:
            # Add final argument
            arguments.append({"start": argument_start_pos, "end": current_position})
            break

        # Finished current argument
        if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",":
            arguments.append({"start": argument_start_pos, "end": current_position})
            argument_start_pos = current_position + 1

        current_position += 1

    return arguments


def str2bool(v):
    """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
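
# Illustrative argparse usage of str2bool (the flag name is hypothetical):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--show-progress", type=str2bool, default=True)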
def hipify(
    project_directory: str,
    show_detailed: bool = False,
    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
    output_directory: str = "",
    header_include_dirs: Iterable = (),
    includes: Iterable = ('*',),
    extra_files: Iterable = (),
    out_of_place_only: bool = False,
    ignores: Iterable = (),
    show_progress: bool = True,
    hip_clang_launch: bool = False,
    is_pytorch_extension: bool = False,
    hipify_extra_files_only: bool = False,
    clean_ctx: Optional[GeneratedFileCleaner] = None
) -> HipifyFinalResult:
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")
        sys.exit(1)

    # If no output directory, provide a default one.
    if not output_directory:
        project_directory = project_directory.rstrip("/")
        output_directory = project_directory + "_amd"

    if project_directory != output_directory:
        includes = [include.replace(project_directory, output_directory) for include in includes]
        ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    all_files = list(matched_files_iter(output_directory, includes=includes,
                                        ignores=ignores, extensions=extensions,
                                        out_of_place_only=out_of_place_only,
                                        is_pytorch_extension=is_pytorch_extension))
    all_files_set = set(all_files)
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
        if f not in all_files_set:
            all_files.append(f)

    # List all files in header_include_paths to ensure they are hipified
    from pathlib import Path
    for header_include_dir in header_include_dirs:
        if os.path.isabs(header_include_dir):
            header_include_dir_path = Path(header_include_dir)
        else:
            header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
        for path in header_include_dir_path.rglob('*'):
            if (
                path.is_file()
                and _fnmatch(str(path), includes)
                and (not _fnmatch(str(path), ignores))
                and match_extensions(path.name, header_extensions)
            ):
                all_files.append(str(path))

    if clean_ctx is None:
        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)

    # Preprocessing statistics.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in (all_files if not hipify_extra_files_only else extra_files):
        preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
                                        stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)

    # Show detailed summary
    if show_detailed:
        compute_stats(stats)

    return HIPIFY_FINAL_RESULT
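
# Illustrative call of hipify for a PyTorch extension (all paths and file names
# below are hypothetical):
#
#     hipify(
#         project_directory="/path/to/my_extension",
#         output_directory="/path/to/my_extension",
#         header_include_dirs=["csrc/include"],
#         includes=["/path/to/my_extension/csrc/*"],
#         extra_files=["/path/to/my_extension/csrc/my_kernel.cu"],
#         is_pytorch_extension=True,
#         show_progress=False,
#     )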