Registry.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. #ifndef C10_UTIL_REGISTRY_H_
  2. #define C10_UTIL_REGISTRY_H_
  3. /**
  4. * Simple registry implementation that uses static variables to
  5. * register object creators during program initialization time.
  6. */
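// NB: This Registry works poorly when you have other namespaces. The
// C10_REGISTER_* macros refer to RegistryName() and Registerer##RegistryName
// unqualified, so invoke them from inside the namespace in which the matching
// C10_DECLARE_*REGISTRY macro was invoked.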
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
  20. namespace c10 {
  21. template <typename KeyType>
  22. inline std::string KeyStrRepr(const KeyType& /*key*/) {
  23. return "[key type printing not supported]";
  24. }
  25. template <>
  26. inline std::string KeyStrRepr(const std::string& key) {
  27. return key;
  28. }
  29. enum RegistryPriority {
  30. REGISTRY_FALLBACK = 1,
  31. REGISTRY_DEFAULT = 2,
  32. REGISTRY_PREFERRED = 3,
  33. };
  34. /**
  35. * @brief A template class that allows one to register classes by keys.
  36. *
  37. * The keys are usually a std::string specifying the name, but can be anything
  38. * that can be used in a std::map.
  39. *
  40. * You should most likely not use the Registry class explicitly, but use the
  41. * helper macros below to declare specific registries as well as registering
  42. * objects.
  43. */
  44. template <class SrcType, class ObjectPtrType, class... Args>
  45. class Registry {
  46. public:
  47. typedef std::function<ObjectPtrType(Args...)> Creator;
  48. Registry(bool warning = true)
  49. : registry_(), priority_(), terminate_(true), warning_(warning) {}
  50. void Register(
  51. const SrcType& key,
  52. Creator creator,
  53. const RegistryPriority priority = REGISTRY_DEFAULT) {
  54. std::lock_guard<std::mutex> lock(register_mutex_);
  55. // The if statement below is essentially the same as the following line:
  56. // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
  57. // << " registered twice.";
  58. // However, TORCH_CHECK_EQ depends on google logging, and since registration
  59. // is carried out at static initialization time, we do not want to have an
  60. // explicit dependency on glog's initialization function.
  61. if (registry_.count(key) != 0) {
  62. auto cur_priority = priority_[key];
  63. if (priority > cur_priority) {
  64. #ifdef DEBUG
  65. std::string warn_msg =
  66. "Overwriting already registered item for key " + KeyStrRepr(key);
  67. fprintf(stderr, "%s\n", warn_msg.c_str());
  68. #endif
  69. registry_[key] = creator;
  70. priority_[key] = priority;
  71. } else if (priority == cur_priority) {
  72. std::string err_msg =
  73. "Key already registered with the same priority: " + KeyStrRepr(key);
  74. fprintf(stderr, "%s\n", err_msg.c_str());
  75. if (terminate_) {
  76. std::exit(1);
  77. } else {
  78. throw std::runtime_error(err_msg);
  79. }
  80. } else if (warning_) {
  81. std::string warn_msg =
  82. "Higher priority item already registered, skipping registration of " +
  83. KeyStrRepr(key);
  84. fprintf(stderr, "%s\n", warn_msg.c_str());
  85. }
  86. } else {
  87. registry_[key] = creator;
  88. priority_[key] = priority;
  89. }
  90. }
  91. void Register(
  92. const SrcType& key,
  93. Creator creator,
  94. const std::string& help_msg,
  95. const RegistryPriority priority = REGISTRY_DEFAULT) {
  96. Register(key, creator, priority);
  97. help_message_[key] = help_msg;
  98. }
  99. inline bool Has(const SrcType& key) {
  100. return (registry_.count(key) != 0);
  101. }
  102. ObjectPtrType Create(const SrcType& key, Args... args) {
  103. auto it = registry_.find(key);
  104. if (it == registry_.end()) {
  105. // Returns nullptr if the key is not registered.
  106. return nullptr;
  107. }
  108. return it->second(args...);
  109. }
  110. /**
  111. * Returns the keys currently registered as a std::vector.
  112. */
  113. std::vector<SrcType> Keys() const {
  114. std::vector<SrcType> keys;
  115. keys.reserve(registry_.size());
  116. for (const auto& it : registry_) {
  117. keys.push_back(it.first);
  118. }
  119. return keys;
  120. }
  121. inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
  122. return help_message_;
  123. }
  124. const char* HelpMessage(const SrcType& key) const {
  125. auto it = help_message_.find(key);
  126. if (it == help_message_.end()) {
  127. return nullptr;
  128. }
  129. return it->second.c_str();
  130. }
  131. // Used for testing, if terminate is unset, Registry throws instead of
  132. // calling std::exit
  133. void SetTerminate(bool terminate) {
  134. terminate_ = terminate;
  135. }
  136. private:
  137. std::unordered_map<SrcType, Creator> registry_;
  138. std::unordered_map<SrcType, RegistryPriority> priority_;
  139. bool terminate_;
  140. const bool warning_;
  141. std::unordered_map<SrcType, std::string> help_message_;
  142. std::mutex register_mutex_;
  143. C10_DISABLE_COPY_AND_ASSIGN(Registry);
  144. };
  145. template <class SrcType, class ObjectPtrType, class... Args>
  146. class Registerer {
  147. public:
  148. explicit Registerer(
  149. const SrcType& key,
  150. Registry<SrcType, ObjectPtrType, Args...>* registry,
  151. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  152. const std::string& help_msg = "") {
  153. registry->Register(key, creator, help_msg);
  154. }
  155. explicit Registerer(
  156. const SrcType& key,
  157. const RegistryPriority priority,
  158. Registry<SrcType, ObjectPtrType, Args...>* registry,
  159. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  160. const std::string& help_msg = "") {
  161. registry->Register(key, creator, help_msg, priority);
  162. }
  163. template <class DerivedType>
  164. static ObjectPtrType DefaultCreator(Args... args) {
  165. return ObjectPtrType(new DerivedType(args...));
  166. }
  167. };
  168. /**
  169. * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
  170. * declaration, as well as creating a convenient typename for its corresponding
  171. * registerer.
  172. */
  173. // Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
  174. // as import and DEFINE as export, because these registry macros will be used
  175. // in downstream shared libraries as well, and one cannot use *_API - the API
  176. // macro will be defined on a per-shared-library basis. Semantically, when one
  177. // declares a typed registry it is always going to be IMPORT, and when one
  178. // defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
  179. // the instantiation unit is always going to be exported.
  180. //
  181. // The only unique condition is when in the same file one does DECLARE and
  182. // DEFINE - in Windows compilers, this generates a warning that dllimport and
  183. // dllexport are mixed, but the warning is fine and linker will be properly
  184. // exporting the symbol. Same thing happens in the gflags flag declaration and
  185. // definition caes.
  186. #define C10_DECLARE_TYPED_REGISTRY( \
  187. RegistryName, SrcType, ObjectType, PtrType, ...) \
  188. C10_IMPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  189. RegistryName(); \
  190. typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
  191. Registerer##RegistryName
  192. #define C10_DEFINE_TYPED_REGISTRY( \
  193. RegistryName, SrcType, ObjectType, PtrType, ...) \
  194. C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  195. RegistryName() { \
  196. static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  197. registry = new ::c10:: \
  198. Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
  199. return registry; \
  200. }
  201. #define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  202. RegistryName, SrcType, ObjectType, PtrType, ...) \
  203. C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  204. RegistryName() { \
  205. static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  206. registry = \
  207. new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
  208. false); \
  209. return registry; \
  210. }
  211. // Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
  212. // creator with comma in its templated arguments.
  213. #define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
  214. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  215. key, RegistryName(), ##__VA_ARGS__);
  216. #define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
  217. RegistryName, key, priority, ...) \
  218. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  219. key, priority, RegistryName(), ##__VA_ARGS__);
  220. #define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
  221. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  222. key, \
  223. RegistryName(), \
  224. Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
  225. ::c10::demangle_type<__VA_ARGS__>());
  226. #define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
  227. RegistryName, key, priority, ...) \
  228. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  229. key, \
  230. priority, \
  231. RegistryName(), \
  232. Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
  233. ::c10::demangle_type<__VA_ARGS__>());
  234. // C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
  235. // std::string as the key type, because that is the most commonly used cases.
  236. #define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  237. C10_DECLARE_TYPED_REGISTRY( \
  238. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  239. #define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
  240. C10_DEFINE_TYPED_REGISTRY( \
  241. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  242. #define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
  243. C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  244. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  245. #define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  246. C10_DECLARE_TYPED_REGISTRY( \
  247. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  248. #define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  249. C10_DEFINE_TYPED_REGISTRY( \
  250. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  251. #define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
  252. RegistryName, ObjectType, ...) \
  253. C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  254. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  255. // C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string
  256. // as the key
  257. // type, because that is the most commonly used cases.
  258. #define C10_REGISTER_CREATOR(RegistryName, key, ...) \
  259. C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
  260. #define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
  261. C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
  262. RegistryName, #key, priority, __VA_ARGS__)
  263. #define C10_REGISTER_CLASS(RegistryName, key, ...) \
  264. C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
  265. #define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
  266. C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
  267. RegistryName, #key, priority, __VA_ARGS__)
  268. } // namespace c10
  269. #endif // C10_UTIL_REGISTRY_H_