123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- #pragma once
- #include <c10/util/irange.h>
- #include <cctype>
- #include <sstream>
- #include <stdexcept>
- #include <string>
- #include <unordered_map>
- #include <utility>
- #include <vector>
- namespace at {
- namespace jit {
- // A template environment is a mapping from template variable names, e.g.,
- // identifier (corresponding to $identifier) to their expansions.
- //
- // This template environment supports storing strings, numbers and lists
- // of strings, and can be chained together (so that lookup proceeds in
- // the top level environment, and then recurses into a parent
- // environment if the key is not found).
struct TemplateEnv {
  TemplateEnv() = default;

  // Creates a child environment chained to 'parent'. Only the address of
  // 'parent' is stored, so 'parent' must outlive this object.
  //
  // NOTE: because the sole parameter is a non-const lvalue reference to
  // TemplateEnv, the language treats this constructor as the copy
  // constructor. "Copying" a TemplateEnv therefore produces an empty child
  // of the source rather than a duplicate of its contents.
  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}

  using string_list = std::vector<std::string>;

  // Add a string 'v' to the map at key 'k'. Any list previously stored at
  // 'k' in this environment is dropped, so a key holds one kind of value.
  void s(const std::string& k, const std::string& v) {
    strings_[k] = v;
    lists_.erase(k);
  }

  // Add a number 'v' to the map at key 'k'. The number is stored in its
  // string form (as produced by std::to_string), so it is read back with
  // s(k). Any list previously stored at 'k' in this environment is dropped.
  template <typename T>
  void d(const std::string& k, const T& v) {
    strings_[k] = std::to_string(v);
    lists_.erase(k);
  }

  // Retrieve the string representation of the value stored at 'k'.
  // Falls back to the parent environment when this environment holds no
  // string at 'k'. Throws std::logic_error if no environment in the chain
  // has one.
  const std::string& s(const std::string& k) const {
    // Single lookup instead of count() followed by at().
    const auto it = strings_.find(k);
    if (it != strings_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->s(k);
    }
    notFound(k);
  }

  // Store a list of strings 'v' in the map at 'k'. Any string previously
  // stored at 'k' in this environment is dropped.
  void v(const std::string& k, const string_list& v) {
    lists_[k] = v;
    strings_.erase(k);
  }

  // Retrieve the list of strings stored at 'k'.
  // Falls back to the parent environment when this environment holds no
  // list at 'k'. Throws std::logic_error if no environment in the chain
  // has one.
  const string_list& v(const std::string& k) const {
    const auto it = lists_.find(k);
    if (it != lists_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->v(k);
    }
    notFound(k);
  }

  // Test whether the value at key 'k' is a string (as opposed to a list).
  // The nearest environment in the chain that defines 'k' decides the
  // answer. Throws std::logic_error if 'k' is not defined anywhere.
  bool keyIsString(const std::string& k) const {
    if (strings_.find(k) != strings_.end()) {
      return true;
    }
    if (lists_.find(k) != lists_.end()) {
      return false;
    }
    if (parent) {
      return parent->keyIsString(k);
    }
    notFound(k);
  }

 private:
  // Reports a failed lookup of 'k'; never returns.
  [[noreturn]] void notFound(const std::string& k) const {
    throw std::logic_error("key not found: " + k);
  }

  std::unordered_map<std::string, std::string> strings_;
  std::unordered_map<std::string, string_list> lists_;
  // Non-owning link to the enclosing environment; only ever read through.
  const TemplateEnv* parent{nullptr};
};
- /*
- # Match $identifier or ${identifier} and replace with the value in env.
- # If the identifier is preceded only by whitespace (if anything) on its
- # line and its value is a list, then it is treated as a block
- # substitution: all lines of all elements are indented to that column.
- # A string value in that position has its continuation lines indented
- # the same way.
- # If the identifier follows non-whitespace text on its line and its value
- # is a list, then the elements are comma separated. ${,foo} will insert a
- # comma before the list if the list is not empty and ${foo,} will insert
- # one after.
- */
- struct CodeTemplate {
- /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
- std::string format(const TemplateEnv& env) const {
- std::stringstream out;
- size_t pos = 0;
- size_t indent = 0;
- bool all_whitespace = true;
- while (pos < template_text.size()) {
- char c = template_text[pos];
- if (c == '$') {
- std::stringstream kss;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- bool comma_before;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- bool comma_after;
- size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
- std::string k = kss.str();
- bool is_string = env.keyIsString(k);
- if (all_whitespace) {
- if (is_string)
- emitStringWithIndents(out, indent, env.s(k));
- else
- emitLinesIndented(out, indent, env.v(k));
- } else {
- if (is_string)
- out << env.s(k);
- else
- emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
- }
- all_whitespace = false;
- pos = new_pos;
- } else {
- out << c;
- if (!isspace(c))
- all_whitespace = false;
- indent++;
- if (c == '\n') {
- indent = 0;
- all_whitespace = true;
- }
- pos++;
- }
- }
- return out.str();
- }
- private:
- using string_list = std::vector<std::string>;
- char charAt(size_t p) const {
- if (p >= template_text.size())
- throw std::logic_error("EOS found in key");
- return template_text[p];
- }
- size_t parseKey(
- size_t pos,
- std::ostream& k,
- bool& comma_before,
- bool& comma_after) const {
- comma_before = false;
- comma_after = false;
- pos++;
- if (charAt(pos) == '{') {
- pos++;
- if (charAt(pos) == ',') {
- comma_before = true;
- pos++;
- }
- pos = parseIdent(pos, k);
- if (charAt(pos) == ',') {
- comma_after = true;
- pos++;
- }
- if (charAt(pos) != '}')
- throw std::logic_error("missing terminating '}'");
- pos++;
- return pos;
- } else {
- return parseIdent(pos, k);
- }
- }
- size_t parseIdent(size_t pos, std::ostream& k) const {
- while (pos < template_text.size() &&
- (isalnum(template_text[pos]) || template_text[pos] == '_')) {
- k << template_text[pos];
- pos++;
- }
- return pos;
- }
- void emitCommaSeparatedList(
- std::ostream& out,
- const string_list& strings,
- bool comma_before,
- bool comma_after) const {
- if (comma_before && !strings.empty())
- out << ", ";
- for (const auto i : c10::irange(strings.size())) {
- if (i > 0)
- out << ", ";
- out << strings[i];
- }
- if (comma_after && !strings.empty())
- out << ", ";
- }
- // These indentation functions follow the convention that they never emit
- // leading or trailing newlines when the input string does not have leading
- // or trailing newlines. It's the responsibility of the calling function
- // to indent correctly in the context.
- void emitIndent(std::ostream& out, size_t indent) const {
- for (const auto i : c10::irange(indent)) {
- (void)i; // Suppress unused variable warning
- out << " ";
- }
- }
- void emitStringWithIndents(
- std::ostream& out,
- size_t indent,
- const std::string& str) const {
- for (auto c : str) {
- out << c;
- if (c == '\n') {
- emitIndent(out, indent);
- }
- }
- }
- void emitLinesIndented(
- std::stringstream& out,
- size_t indent,
- const string_list& strings) const {
- for (const auto i : c10::irange(strings.size())) {
- if (i > 0)
- emitIndent(out, indent);
- emitStringWithIndents(out, indent, strings[i]);
- if (i + 1 != strings.size())
- out << "\n";
- }
- }
- std::string template_text;
- };
- static inline std::string format(const std::string& fmt, TemplateEnv& env) {
- return CodeTemplate(fmt).format(env);
- }
- } // namespace jit
- } // namespace at
|