_pattern_matcher.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. import json
  2. import math
  3. import os
  4. import re
  5. from typing import Dict, List, Optional, Set
  6. import torch
  7. from torch.profiler import profile
  8. import torch.utils.benchmark as benchmark
  9. from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs
  10. from torch._C._profiler import (_ProfilerEvent, _ExtraFields_TorchOp,
  11. _ExtraFields_PyCCall, _ExtraFields_PyCall,
  12. _EventType)
  13. class Pattern:
  14. '''
  15. Base class for all patterns, subclass this class and implement match()
  16. to define custom patterns.
  17. In subclass, define description and skip property.
  18. '''
  19. def __init__(self, prof: profile, should_benchmark: bool = False):
  20. self.prof = prof
  21. self.should_benchmark = should_benchmark
  22. self.name = "Please specify a name for pattern"
  23. self.description = "Please specify a description for pattern"
  24. self.url = ""
  25. assert prof.profiler is not None and prof.profiler.kineto_results is not None
  26. self.event_tree = prof.profiler.kineto_results.experimental_event_tree(
  27. )
  28. self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
  29. for event in self.event_tree:
  30. self.tid_root.setdefault(event.start_tid, []).append(event)
  31. @property
  32. def skip(self):
  33. return False
  34. def report(self, event: _ProfilerEvent):
  35. msg = f"{self.description}\n[Source Code Location] {source_code_location(event)}"
  36. return msg
  37. def eventTreeTraversal(self):
  38. '''
  39. Traverse the event tree and yield all events.
  40. Override this method in subclass to customize the traversal.
  41. '''
  42. yield from traverse_dfs(self.event_tree)
  43. def summary(self, events: List[_ProfilerEvent]):
  44. default_summary = f"{self.name}: {len(events)} events matched."
  45. if self.should_benchmark:
  46. # If benchmark summary is not empty, use it.
  47. return self.benchmark_summary(
  48. events) if hasattr( # type: ignore[attr-defined]
  49. self, 'benchmark') else default_summary
  50. return default_summary
  51. def benchmark_summary(self, events: List[_ProfilerEvent]):
  52. def format_time(time_ns: int):
  53. unit_lst = ["ns", "us", "ms"]
  54. for unit in unit_lst:
  55. if time_ns < 1000:
  56. return f"{time_ns:.2f} {unit}"
  57. time_ns //= 1000
  58. return f"{time_ns:.2f} s"
  59. assert hasattr(self, 'benchmark'), 'Please implement benchmark()'
  60. shapes_factor_map = self.benchmark( # type: ignore[attr-defined]
  61. events)
  62. original_time = sum(event.duration_time_ns for event in events)
  63. new_time = sum(shapes_factor_map[input_shapes(event)] *
  64. event.duration_time_ns for event in events)
  65. return (
  66. f"{self.name}: {len(events)} events matched. "
  67. f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
  68. )
  69. def match(self, event: _ProfilerEvent):
  70. '''
  71. Return True if the event matches the pattern.
  72. This method should be overriden in subclass.
  73. '''
  74. raise NotImplementedError
  75. def matched_events(self):
  76. if self.skip:
  77. return []
  78. matched_events = []
  79. for event in self.eventTreeTraversal():
  80. if self.match(event):
  81. matched_events.append(event)
  82. return matched_events
  83. def root_of(self, event: _ProfilerEvent):
  84. while event.parent:
  85. event = event.parent
  86. return event
  87. def siblings_of(self, event: _ProfilerEvent):
  88. if event.parent:
  89. children = event.parent.children
  90. else:
  91. children = self.tid_root[event.start_tid]
  92. index = children.index(event)
  93. return children[:index], children[index + 1:]
  94. def next_of(self, event: _ProfilerEvent):
  95. _, next_events = self.siblings_of(event)
  96. return next_events[0] if next_events else None
  97. def prev_of(self, event: _ProfilerEvent):
  98. prev_events, _ = self.siblings_of(event)
  99. return prev_events[-1] if prev_events else None
  100. def go_up_until(self, event: _ProfilerEvent, predicate):
  101. if not event:
  102. return None
  103. while event.parent and not predicate(event):
  104. event = event.parent
  105. return event
  106. # Patterns
  107. class NamePattern(Pattern):
  108. def __init__(self,
  109. prof: profile,
  110. name: str,
  111. should_benchmark: bool = False):
  112. super().__init__(prof, should_benchmark)
  113. self.description = f"Matched Name Event: {name}"
  114. self.name = name
  115. def match(self, event: _ProfilerEvent):
  116. return re.search(self.name, event.name) is not None
  117. class ExtraCUDACopyPattern(Pattern):
  118. '''
  119. This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU.
  120. example: torch.zeros((100, 100)).to("cuda")
  121. Pattern:
  122. build-in method |build-in method
  123. ... | aten::to
  124. aten::fill_/aten::zero_ | aten::_to_copy
  125. Algorithm:
  126. We start at node aten::to, go parent events' previous events,
  127. and check if we have a aten::fill_/aten::zero_ as we keep going down the tree.
  128. We always select the last child in the children list when we go down the tree.
  129. If at any step we failed, it is not a match.
  130. '''
  131. def __init__(self, prof: profile, should_benchmark: bool = False):
  132. super().__init__(prof, should_benchmark)
  133. self.name = "Extra CUDA Copy Pattern"
  134. self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
  135. self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
  136. self.init_ops = {
  137. "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
  138. }
  139. @property
  140. def skip(self):
  141. return not self.prof.with_stack or not self.prof.record_shapes
  142. def match(self, event):
  143. # TODO: We should also check tensor identities
  144. if event.name != "aten::to":
  145. return False
  146. to_event = event
  147. if not event.children:
  148. return False
  149. event = event.children[-1]
  150. if event.name != "aten::_to_copy":
  151. return False
  152. if not event.children:
  153. return False
  154. event = event.children[-1]
  155. if event.name != "aten::copy_":
  156. return False
  157. # aten::copy_ should have the first 2 args dtype the same
  158. dtypes = input_dtypes(event)
  159. if len(dtypes) < 2:
  160. return False
  161. if dtypes[0] is None or dtypes[0] != dtypes[1]:
  162. return False
  163. event = to_event
  164. # Up one level
  165. event = event.parent
  166. if event is None:
  167. return False
  168. # Check if we have a aten::fill_ in previous leaf
  169. event = self.prev_of(event)
  170. if event is None:
  171. return False
  172. while event.children:
  173. event = event.children[-1]
  174. # aten::zero_ is a special optimzation case where fill_ is not called
  175. if event.name in self.init_ops:
  176. return True
  177. return event.name in self.init_ops
  178. # TODO: Check if tensor is reused
  179. def benchmark(self, events: List[_ProfilerEvent]):
  180. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  181. for shape in shapes_factor_map:
  182. size = shape[0]
  183. to_timer = benchmark.Timer(stmt='torch.ones(size).to("cuda")',
  184. globals={'size': size})
  185. de_timer = benchmark.Timer(stmt='torch.ones(size, device="cuda")',
  186. globals={'size': size})
  187. to_time = to_timer.timeit(10).mean
  188. de_time = de_timer.timeit(10).mean
  189. shapes_factor_map[shape] = de_time / to_time
  190. return shapes_factor_map
  191. class ForLoopIndexingPattern(Pattern):
  192. '''
  193. This pattern identifies if we use a for loop to index a tensor that
  194. can be vectorized.
  195. example:
  196. tensor = torch.empty((100, 100))
  197. for i in range(100):
  198. tensor[i] = i
  199. Pattern:
  200. aten::select | ... | aten::select | ... (Repeat)
  201. Algorithm:
  202. We start at node aten::select, and we check if we can find this alternating patterns.
  203. We also keep a dictionary to avoid duplicate match in the for loop.
  204. '''
  205. def __init__(self, prof: profile, should_benchmark: bool = False):
  206. super().__init__(prof, should_benchmark)
  207. self.name = "For Loop Indexing Pattern"
  208. self.description = "For loop indexing detected. Vectorization recommended."
  209. self.visited: Set[int] = set()
  210. def eventTreeTraversal(self):
  211. '''
  212. We need to use BFS traversal order to avoid duplicate match.
  213. '''
  214. yield from traverse_bfs(self.event_tree)
  215. def match(self, event: _ProfilerEvent):
  216. if event.name != "aten::select":
  217. return False
  218. if event.id in self.visited:
  219. return False
  220. repeat_count = 1
  221. _, next = self.siblings_of(event)
  222. if len(next) <= 1:
  223. return False
  224. # Custom event list matching
  225. def same_ops(list1, list2):
  226. if len(list1) != len(list2):
  227. return False
  228. for op1, op2 in zip(list1, list2):
  229. if op1.name != op2.name:
  230. return False
  231. return True
  232. # Record the ops between two aten::select
  233. next_select_idx = index_of_first_match(
  234. next, lambda e: e.name == "aten::select")
  235. if next_select_idx is None:
  236. return False
  237. indexing_ops = [event] + next[:next_select_idx]
  238. next = next[len(indexing_ops) - 1:]
  239. for i in range(0, len(next), len(indexing_ops)):
  240. if same_ops(indexing_ops, next[i:i + len(indexing_ops)]):
  241. repeat_count += 1
  242. self.visited.add(next[i].id)
  243. else:
  244. break
  245. return repeat_count >= 10
  246. class FP32MatMulPattern(Pattern):
  247. def __init__(self, prof: profile, should_benchmark: bool = False):
  248. super().__init__(prof, should_benchmark)
  249. self.name = "FP32 MatMul Pattern"
  250. self.description = (
  251. "You are currently using GPU that supports TF32. "
  252. "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
  253. )
  254. self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
  255. @property
  256. def skip(self):
  257. if torch.version.hip is not None:
  258. has_tf32 = False
  259. else:
  260. # Anything less than sm_80 is not Ampere which doesn't support TF32
  261. has_tf32 = all(
  262. int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
  263. return has_tf32 is False or super().skip or not self.prof.record_shapes
  264. def match(self, event: _ProfilerEvent):
  265. # If we saw this pattern once, we don't need to match it again
  266. if event.tag != _EventType.TorchOp:
  267. return False
  268. assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
  269. if event.name == "aten::mm":
  270. if event.extra_fields.allow_tf32_cublas is False:
  271. return True
  272. return False
  273. def report(self, event: _ProfilerEvent):
  274. return self.description
  275. def benchmark(self, events: List[_ProfilerEvent]):
  276. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  277. for shape in shapes_factor_map:
  278. matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
  279. matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
  280. fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)',
  281. globals={
  282. "matrixA": matrixA,
  283. "matrixB": matrixB
  284. })
  285. tf32_timer = benchmark.Timer(
  286. stmt='torch.mm(matrixA, matrixB)',
  287. setup='torch.backends.cuda.matmul.allow_tf32 = True',
  288. globals={
  289. "matrixA": matrixA,
  290. "matrixB": matrixB
  291. })
  292. torch.backends.cuda.matmul.allow_tf32 = False
  293. fp32_time = fp32_timer.timeit(10).mean
  294. tf32_time = tf32_timer.timeit(10).mean
  295. shapes_factor_map[shape] = tf32_time / fp32_time
  296. return shapes_factor_map
  297. class OptimizerSingleTensorPattern(Pattern):
  298. '''
  299. This pattern identifies if we are using the single-tensor version of an optimizer.
  300. example:
  301. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  302. By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when
  303. the kernels are relatively small.
  304. Pattern:
  305. XXXXX: _single_tenser_<OPTIMIZER_NAME>
  306. Algorithm:
  307. String match
  308. '''
  309. def __init__(self, prof: profile, should_benchmark: bool = False):
  310. super().__init__(prof, should_benchmark)
  311. self.name = "Optimizer Single Tensor Pattern"
  312. self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
  313. self.description = (
  314. "Deteced optimizer running with single tensor implementation. "
  315. "Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
  316. )
  317. self.url = ""
  318. def match(self, event: _ProfilerEvent):
  319. for optimizer in self.optimizers_with_foreach:
  320. if event.name.endswith(f"_single_tensor_{optimizer}"):
  321. return True
  322. return False
  323. class SynchronizedDataLoaderPattern(Pattern):
  324. '''
  325. This pattern identifies if we are using num_workers=0 in DataLoader.
  326. example:
  327. torch.utils.data.DataLoader(dataset, batch_size=batch_size)
  328. Add num_workers=N to the arguments. N depends on system configuration.
  329. Pattern:
  330. dataloader.py(...): __iter__
  331. dataloader.py(...): _get_iterator
  332. NOT dataloader.py(...): check_worker_number_rationality
  333. Algorithm:
  334. If we don't see check_worker_number_rationality call in the dataloader __iter__,
  335. It is not an asynchronous dataloader.
  336. '''
  337. def __init__(self, prof: profile, should_benchmark: bool = False):
  338. super().__init__(prof, should_benchmark)
  339. self.name = "Synchronized DataLoader Pattern"
  340. self.description = (
  341. "Detected DataLoader running with synchronized implementation. "
  342. "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
  343. )
  344. self.url = (
  345. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  346. "#enable-async-data-loading-and-augmentation")
  347. def match(self, event: _ProfilerEvent):
  348. def is_dataloader_function(name: str, function_name: str):
  349. return name.startswith(
  350. os.path.join("torch", "utils", "data",
  351. "dataloader.py")) and name.endswith(function_name)
  352. # TODO: fixme! Due to lifetime issues of the function name, this field might
  353. # actually point to an already freed string when the even is a PyCall.
  354. # Just silently skip this to unblock testing.
  355. try:
  356. event.name
  357. except UnicodeDecodeError:
  358. return False
  359. if not is_dataloader_function(event.name, "__iter__"):
  360. return False
  361. if not event.children:
  362. return False
  363. event = event.children[0]
  364. if not is_dataloader_function(event.name, "_get_iterator"):
  365. return False
  366. if not event.children:
  367. return False
  368. event = event.children[0]
  369. return not is_dataloader_function(event.name,
  370. "check_worker_number_rationality")
  371. # TODO: We should also check if the loader is bottleneck.
  372. class GradNotSetToNonePattern(Pattern):
  373. '''
  374. This pattern identifies if we are not setting grad to None in zero_grad.
  375. example:
  376. optimizer.zero_grad()
  377. By setting set_to_none=True, we can gain speedup
  378. Pattern:
  379. XXXXX: _zero_grad
  380. NOT aten::zeros
  381. aten::zero_
  382. aten::zero_ is called on each parameter in the model.
  383. We also want to make sure it is not called by aten::zeros.
  384. Algorithm:
  385. String match
  386. '''
  387. def __init__(self, prof: profile, should_benchmark: bool = False):
  388. super().__init__(prof, should_benchmark)
  389. self.name = "Gradient Set To Zero Instead of None Pattern"
  390. self.description = (
  391. "Detected gradient set to zero instead of None. "
  392. "Please add 'set_to_none=True' when calling zero_grad().")
  393. self.url = (
  394. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  395. "#disable-gradient-calculation-for-validation-or-inference")
  396. def match(self, event: _ProfilerEvent):
  397. if not event.name.endswith(": zero_grad"):
  398. return False
  399. if not event.children:
  400. return False
  401. for sub_event in traverse_dfs(event.children):
  402. if sub_event.name == "aten::zero_" and sub_event.parent.name != "aten::zeros":
  403. return True
  404. # TODO: We should also check if the optimizer's numerical behavior will change.
  405. return False
  406. class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
  407. '''
  408. This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d.
  409. Bias doesn't do anything when followed by batchnorm.
  410. Pattern:
  411. nn.Module: Conv2d | nn.Module: BatchNorm2d
  412. ...
  413. aten::conv2d AND dtype of third argument is not null
  414. The third argument is the bias
  415. Algorithm:
  416. String match
  417. '''
  418. def __init__(self, prof: profile, should_benchmark: bool = False):
  419. super().__init__(prof, should_benchmark)
  420. self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
  421. self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
  422. self.url = (
  423. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  424. "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm")
  425. @property
  426. def skip(self):
  427. return self.prof.record_shapes is False or super().skip
  428. def match(self, event: _ProfilerEvent):
  429. if event.name != "aten::conv2d":
  430. return False
  431. if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None:
  432. return False
  433. # This means bias=True
  434. event = self.go_up_until(
  435. event, lambda e: e.name.startswith("nn.Module: Conv2d"))
  436. if not event:
  437. return False
  438. event = self.next_of(event)
  439. if not event:
  440. return False
  441. return event.name.startswith("nn.Module: BatchNorm2d")
  442. class MatMulDimInFP16Pattern(Pattern):
  443. def __init__(self, prof: profile, should_benchmark: bool = False):
  444. super().__init__(prof, should_benchmark)
  445. self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
  446. self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
  447. self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"
  448. @property
  449. def skip(self):
  450. return not self.prof.with_stack or not self.prof.record_shapes
  451. def match(self, event: _ProfilerEvent):
  452. def mutiple_of(shapes, multiple):
  453. return all(dim % multiple == 0 for shape in shapes
  454. for dim in shape[-2:])
  455. if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"):
  456. return False
  457. if not input_dtypes(event):
  458. return False
  459. arg_dtype = input_dtypes(event)[0]
  460. if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of(input_shapes(event), 8):
  461. return True
  462. return False
  463. def benchmark(self, events: List[_ProfilerEvent]):
  464. def closest_multiple(shapes, multiple):
  465. return [multiple * math.ceil(shape / multiple) for shape in shapes]
  466. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  467. for shape in shapes_factor_map:
  468. matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16)
  469. matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16)
  470. not_aligned_dim_timer = benchmark.Timer(
  471. stmt='torch.mm(matrixA, matrixB)',
  472. globals={
  473. "matrixA": matrixA,
  474. "matrixB": matrixB
  475. })
  476. matrixA = torch.randn(closest_multiple(shape[0], 8),
  477. device="cuda",
  478. dtype=torch.float16)
  479. matrixB = torch.randn(closest_multiple(shape[1], 8),
  480. device="cuda",
  481. dtype=torch.float16)
  482. aligned_dim_timer = benchmark.Timer(
  483. stmt='torch.mm(matrixA, matrixB)',
  484. globals={
  485. "matrixA": matrixA,
  486. "matrixB": matrixB
  487. })
  488. not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean
  489. aligned_dim_time = aligned_dim_timer.timeit(10).mean
  490. shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time
  491. return shapes_factor_map
  492. def source_code_location(event: Optional[_ProfilerEvent]):
  493. while event:
  494. if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
  495. assert isinstance(event.extra_fields,
  496. _ExtraFields_PyCall) or isinstance(
  497. event.extra_fields, _ExtraFields_PyCCall)
  498. if not event.extra_fields.caller.file_name.startswith("torch" +
  499. os.sep):
  500. return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
  501. event = event.parent
  502. return "No source code location found"
  503. def input_shapes(event: _ProfilerEvent):
  504. assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
  505. return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs)
  506. def input_dtypes(event: _ProfilerEvent):
  507. assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
  508. return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs)
  509. def report_all_anti_patterns(prof,
  510. should_benchmark: bool = False,
  511. print_enable: bool = True,
  512. json_report_dir: str = None):
  513. report_dict: Dict = {}
  514. anti_patterns = [
  515. ExtraCUDACopyPattern(prof, should_benchmark),
  516. # ForLoopIndexingPattern(prof, should_benchmark),
  517. FP32MatMulPattern(prof, should_benchmark),
  518. OptimizerSingleTensorPattern(prof, should_benchmark),
  519. SynchronizedDataLoaderPattern(prof, should_benchmark),
  520. GradNotSetToNonePattern(prof, should_benchmark),
  521. Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark),
  522. MatMulDimInFP16Pattern(prof, should_benchmark)
  523. ]
  524. reported = set()
  525. summaries = []
  526. message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"]
  527. message_list.append("Matched Events:")
  528. for anti_pattern in anti_patterns:
  529. matched_events = anti_pattern.matched_events()
  530. if not matched_events:
  531. continue
  532. summaries.append(anti_pattern.summary(matched_events))
  533. for event in matched_events:
  534. report_msg = anti_pattern.report(event)
  535. if report_msg not in reported:
  536. message_list.append(report_msg)
  537. reported.add(report_msg)
  538. src_location, line_no = source_code_location(event).split(":")
  539. report_dict.setdefault(src_location, []).append({
  540. "line_number": int(line_no),
  541. "name": anti_pattern.name,
  542. "url": anti_pattern.url,
  543. "message": anti_pattern.description,
  544. })
  545. if json_report_dir is not None:
  546. json_report_path = os.path.join(json_report_dir,
  547. "torchtidy_report.json")
  548. if os.path.exists(json_report_path):
  549. with open(json_report_path, "r") as f:
  550. exisiting_report = json.load(f)
  551. exisiting_report.update(report_dict)
  552. report_dict = exisiting_report
  553. with open(json_report_path, "w") as f:
  554. json.dump(report_dict, f, indent=4)
  555. message_list.append("Summary:")
  556. message_list += summaries
  557. message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
  558. if print_enable:
  559. print("\n".join(message_list))