retrieve_prs_data.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import json
  2. import locale
  3. import os
  4. import re
  5. import subprocess
  6. from collections import namedtuple
  7. from os.path import expanduser
  8. import requests
  9. Features = namedtuple(
  10. "Features",
  11. [
  12. "title",
  13. "body",
  14. "pr_number",
  15. "files_changed",
  16. "labels",
  17. ],
  18. )
  19. def dict_to_features(dct):
  20. return Features(
  21. title=dct["title"],
  22. body=dct["body"],
  23. pr_number=dct["pr_number"],
  24. files_changed=dct["files_changed"],
  25. labels=dct["labels"],
  26. )
  27. def features_to_dict(features):
  28. return dict(features._asdict())
  29. def run(command):
  30. """Returns (return-code, stdout, stderr)"""
  31. p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
  32. output, err = p.communicate()
  33. rc = p.returncode
  34. enc = locale.getpreferredencoding()
  35. output = output.decode(enc)
  36. err = err.decode(enc)
  37. return rc, output.strip(), err.strip()
  38. def commit_body(commit_hash):
  39. cmd = f"git log -n 1 --pretty=format:%b {commit_hash}"
  40. ret, out, err = run(cmd)
  41. return out if ret == 0 else None
  42. def commit_title(commit_hash):
  43. cmd = f"git log -n 1 --pretty=format:%s {commit_hash}"
  44. ret, out, err = run(cmd)
  45. return out if ret == 0 else None
  46. def commit_files_changed(commit_hash):
  47. cmd = f"git diff-tree --no-commit-id --name-only -r {commit_hash}"
  48. ret, out, err = run(cmd)
  49. return out.split("\n") if ret == 0 else None
  50. def parse_pr_number(body, commit_hash, title):
  51. regex = r"(#[0-9]+)"
  52. matches = re.findall(regex, title)
  53. if len(matches) == 0:
  54. if "revert" not in title.lower() and "updating submodules" not in title.lower():
  55. print(f"[{commit_hash}: {title}] Could not parse PR number, ignoring PR")
  56. return None
  57. if len(matches) > 1:
  58. print(f"[{commit_hash}: {title}] Got two PR numbers, using the last one")
  59. return matches[-1][1:]
  60. return matches[0][1:]
  61. def get_ghstack_token():
  62. pattern = "github_oauth = (.*)"
  63. with open(expanduser("~/.ghstackrc"), "r+") as f:
  64. config = f.read()
  65. matches = re.findall(pattern, config)
  66. if len(matches) == 0:
  67. raise RuntimeError("Can't find a github oauth token")
  68. return matches[0]
  69. token = get_ghstack_token()
  70. headers = {"Authorization": f"token {token}"}
  71. def run_query(query):
  72. request = requests.post("https://api.github.com/graphql", json={"query": query}, headers=headers)
  73. if request.status_code == 200:
  74. return request.json()
  75. else:
  76. raise Exception(f"Query failed to run by returning code of {request.status_code}. {query}")
  77. def gh_labels(pr_number):
  78. query = f"""
  79. {{
  80. repository(owner: "pytorch", name: "vision") {{
  81. pullRequest(number: {pr_number}) {{
  82. labels(first: 10) {{
  83. edges {{
  84. node {{
  85. name
  86. }}
  87. }}
  88. }}
  89. }}
  90. }}
  91. }}
  92. """
  93. query = run_query(query)
  94. edges = query["data"]["repository"]["pullRequest"]["labels"]["edges"]
  95. return [edge["node"]["name"] for edge in edges]
  96. def get_features(commit_hash, return_dict=False):
  97. title, body, files_changed = (
  98. commit_title(commit_hash),
  99. commit_body(commit_hash),
  100. commit_files_changed(commit_hash),
  101. )
  102. pr_number = parse_pr_number(body, commit_hash, title)
  103. labels = []
  104. if pr_number is not None:
  105. labels = gh_labels(pr_number)
  106. result = Features(title, body, pr_number, files_changed, labels)
  107. if return_dict:
  108. return features_to_dict(result)
  109. return result
  110. class CommitDataCache:
  111. def __init__(self, path="results/data.json"):
  112. self.path = path
  113. self.data = {}
  114. if os.path.exists(path):
  115. self.data = self.read_from_disk()
  116. def get(self, commit):
  117. if commit not in self.data.keys():
  118. # Fetch and cache the data
  119. self.data[commit] = get_features(commit)
  120. self.write_to_disk()
  121. return self.data[commit]
  122. def read_from_disk(self):
  123. with open(self.path) as f:
  124. data = json.load(f)
  125. data = {commit: dict_to_features(dct) for commit, dct in data.items()}
  126. return data
  127. def write_to_disk(self):
  128. data = {commit: features._asdict() for commit, features in self.data.items()}
  129. with open(self.path, "w") as f:
  130. json.dump(data, f)
  131. def get_commits_between(base_version, new_version):
  132. cmd = f"git merge-base {base_version} {new_version}"
  133. rc, merge_base, _ = run(cmd)
  134. assert rc == 0
  135. # Returns a list of something like
  136. # b33e38ec47 Allow a higher-precision step type for Vec256::arange (#34555)
  137. cmd = f"git log --reverse --oneline {merge_base}..{new_version}"
  138. rc, commits, _ = run(cmd)
  139. assert rc == 0
  140. log_lines = commits.split("\n")
  141. hashes, titles = zip(*[log_line.split(" ", 1) for log_line in log_lines])
  142. return hashes, titles
  143. def convert_to_dataframes(feature_list):
  144. import pandas as pd
  145. df = pd.DataFrame.from_records(feature_list, columns=Features._fields)
  146. return df
  147. def main(base_version, new_version):
  148. hashes, titles = get_commits_between(base_version, new_version)
  149. cdc = CommitDataCache("data.json")
  150. for idx, commit in enumerate(hashes):
  151. if idx % 10 == 0:
  152. print(f"{idx} / {len(hashes)}")
  153. cdc.get(commit)
  154. return cdc
  155. if __name__ == "__main__":
  156. # d = get_features('2ab93592529243862ce8ad5b6acf2628ef8d0dc8')
  157. # print(d)
  158. # hashes, titles = get_commits_between("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")
  159. # Usage: change the tags below accordingly to the current release, then save the json with
  160. # cdc.write_to_disk().
  161. # Then you can use classify_prs.py (as a notebook)
  162. # to open the json and generate the release notes semi-automatically.
  163. cdc = main("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")
  164. from IPython import embed
  165. embed()