code_template.h 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #pragma once
#include <c10/util/irange.h>

#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
  7. namespace at {
  8. namespace jit {
// A template environment is a mapping from template variable names, e.g.,
// identifier (corresponding to $identifier), to their expansions.
//
// This template environment supports storing strings, numbers and lists
// of strings, and can be chained together (so that lookup proceeds in
// the top level environment, and then recurses into a parent
// environment if the key is not found).
  16. struct TemplateEnv {
  17. TemplateEnv() = default;
  18. TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
  19. using string_list = std::vector<std::string>;
  20. // Add a string 'v' to the map at key 'k'.
  21. void s(const std::string& k, const std::string& v) {
  22. strings_[k] = v;
  23. lists_.erase(k);
  24. }
  25. // Add a number 'v' to the map at key 'k'
  26. template <typename T>
  27. void d(const std::string& k, const T& v) {
  28. strings_[k] = c10::to_string(v);
  29. lists_.erase(k);
  30. }
  31. // Retrieve the string representation of the value stored at 'k' from the map.
  32. // Raises an exception if the key is not found.
  33. const std::string& s(const std::string& k) const {
  34. if (strings_.count(k) == 0) {
  35. if (parent) {
  36. return parent->s(k);
  37. }
  38. notFound(k);
  39. }
  40. return strings_.at(k);
  41. }
  42. // Store a list of strings 'v' in the map at 'k'.
  43. void v(const std::string& k, const string_list& v) {
  44. lists_[k] = v;
  45. strings_.erase(k);
  46. }
  47. // Retrieve a list of strings stored at 'k' from the map.
  48. // Raises an exception if the key is not found.
  49. const string_list& v(const std::string& k) const {
  50. if (lists_.count(k) == 0) {
  51. if (parent) {
  52. return parent->v(k);
  53. }
  54. notFound(k);
  55. }
  56. return lists_.at(k);
  57. }
  58. // Test if a string 'k' is a string (as opposed to a list.)
  59. bool keyIsString(const std::string& k) const {
  60. if (strings_.count(k) > 0)
  61. return true;
  62. if (lists_.count(k) > 0)
  63. return false;
  64. if (parent)
  65. return parent->keyIsString(k);
  66. notFound(k);
  67. }
  68. private:
  69. [[noreturn]] void notFound(const std::string& k) const {
  70. std::stringstream ss;
  71. ss << "key not found: " << k;
  72. throw std::logic_error(ss.str());
  73. }
  74. std::unordered_map<std::string, std::string> strings_;
  75. std::unordered_map<std::string, string_list> lists_;
  76. TemplateEnv* parent{nullptr};
  77. };
// Match $identifier or ${identifier} and replace it with the value in env.
// If the identifier is preceded only by whitespace on its line and its
// value is a list, then it is treated as a block substitution: all lines
// of all elements are indented to match.
// If the identifier is preceded by non-whitespace text on its line and its
// value is a list, then the elements are comma separated. ${,foo} will
// insert a comma before the list if the list is not empty, and ${foo,}
// will insert one after.
  87. struct CodeTemplate {
  88. /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
  89. std::string format(const TemplateEnv& env) const {
  90. std::stringstream out;
  91. size_t pos = 0;
  92. size_t indent = 0;
  93. bool all_whitespace = true;
  94. while (pos < template_text.size()) {
  95. char c = template_text[pos];
  96. if (c == '$') {
  97. std::stringstream kss;
  98. // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  99. bool comma_before;
  100. // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  101. bool comma_after;
  102. size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
  103. std::string k = kss.str();
  104. bool is_string = env.keyIsString(k);
  105. if (all_whitespace) {
  106. if (is_string)
  107. emitStringWithIndents(out, indent, env.s(k));
  108. else
  109. emitLinesIndented(out, indent, env.v(k));
  110. } else {
  111. if (is_string)
  112. out << env.s(k);
  113. else
  114. emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
  115. }
  116. all_whitespace = false;
  117. pos = new_pos;
  118. } else {
  119. out << c;
  120. if (!isspace(c))
  121. all_whitespace = false;
  122. indent++;
  123. if (c == '\n') {
  124. indent = 0;
  125. all_whitespace = true;
  126. }
  127. pos++;
  128. }
  129. }
  130. return out.str();
  131. }
  132. private:
  133. using string_list = std::vector<std::string>;
  134. char charAt(size_t p) const {
  135. if (p >= template_text.size())
  136. throw std::logic_error("EOS found in key");
  137. return template_text[p];
  138. }
  139. size_t parseKey(
  140. size_t pos,
  141. std::ostream& k,
  142. bool& comma_before,
  143. bool& comma_after) const {
  144. comma_before = false;
  145. comma_after = false;
  146. pos++;
  147. if (charAt(pos) == '{') {
  148. pos++;
  149. if (charAt(pos) == ',') {
  150. comma_before = true;
  151. pos++;
  152. }
  153. pos = parseIdent(pos, k);
  154. if (charAt(pos) == ',') {
  155. comma_after = true;
  156. pos++;
  157. }
  158. if (charAt(pos) != '}')
  159. throw std::logic_error("missing terminating '}'");
  160. pos++;
  161. return pos;
  162. } else {
  163. return parseIdent(pos, k);
  164. }
  165. }
  166. size_t parseIdent(size_t pos, std::ostream& k) const {
  167. while (pos < template_text.size() &&
  168. (isalnum(template_text[pos]) || template_text[pos] == '_')) {
  169. k << template_text[pos];
  170. pos++;
  171. }
  172. return pos;
  173. }
  174. void emitCommaSeparatedList(
  175. std::ostream& out,
  176. const string_list& strings,
  177. bool comma_before,
  178. bool comma_after) const {
  179. if (comma_before && !strings.empty())
  180. out << ", ";
  181. for (const auto i : c10::irange(strings.size())) {
  182. if (i > 0)
  183. out << ", ";
  184. out << strings[i];
  185. }
  186. if (comma_after && !strings.empty())
  187. out << ", ";
  188. }
  189. // These indentation functions follow the convention that they never emit
  190. // leading or trailing newlines when the input string does not have leading
  191. // or trailing newlines. It's the responsibility of the calling function
  192. // to indent correctly in the context.
  193. void emitIndent(std::ostream& out, size_t indent) const {
  194. for (const auto i : c10::irange(indent)) {
  195. (void)i; // Suppress unused variable warning
  196. out << " ";
  197. }
  198. }
  199. void emitStringWithIndents(
  200. std::ostream& out,
  201. size_t indent,
  202. const std::string& str) const {
  203. for (auto c : str) {
  204. out << c;
  205. if (c == '\n') {
  206. emitIndent(out, indent);
  207. }
  208. }
  209. }
  210. void emitLinesIndented(
  211. std::stringstream& out,
  212. size_t indent,
  213. const string_list& strings) const {
  214. for (const auto i : c10::irange(strings.size())) {
  215. if (i > 0)
  216. emitIndent(out, indent);
  217. emitStringWithIndents(out, indent, strings[i]);
  218. if (i + 1 != strings.size())
  219. out << "\n";
  220. }
  221. }
  222. std::string template_text;
  223. };
  224. static inline std::string format(const std::string& fmt, TemplateEnv& env) {
  225. return CodeTemplate(fmt).format(env);
  226. }
  227. } // namespace jit
  228. } // namespace at