plot_mode.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. from .plot_interval import PlotInterval
  2. from .plot_object import PlotObject
  3. from .util import parse_option_string
  4. from sympy.core.symbol import Symbol
  5. from sympy.core.sympify import sympify
  6. from sympy.geometry.entity import GeometryEntity
  7. from sympy.utilities.iterables import is_sequence
  8. class PlotMode(PlotObject):
  9. """
  10. Grandparent class for plotting
  11. modes. Serves as interface for
  12. registration, lookup, and init
  13. of modes.
  14. To create a new plot mode,
  15. inherit from PlotModeBase
  16. or one of its children, such
  17. as PlotSurface or PlotCurve.
  18. """
  19. ## Class-level attributes
  20. ## used to register and lookup
  21. ## plot modes. See PlotModeBase
  22. ## for descriptions and usage.
  23. i_vars, d_vars = '', ''
  24. intervals = []
  25. aliases = []
  26. is_default = False
  27. ## Draw is the only method here which
  28. ## is meant to be overridden in child
  29. ## classes, and PlotModeBase provides
  30. ## a base implementation.
  31. def draw(self):
  32. raise NotImplementedError()
  33. ## Everything else in this file has to
  34. ## do with registration and retrieval
  35. ## of plot modes. This is where I've
  36. ## hidden much of the ugliness of automatic
  37. ## plot mode divination...
  38. ## Plot mode registry data structures
  39. _mode_alias_list = []
  40. _mode_map = {
  41. 1: {1: {}, 2: {}},
  42. 2: {1: {}, 2: {}},
  43. 3: {1: {}, 2: {}},
  44. } # [d][i][alias_str]: class
  45. _mode_default_map = {
  46. 1: {},
  47. 2: {},
  48. 3: {},
  49. } # [d][i]: class
  50. _i_var_max, _d_var_max = 2, 3
  51. def __new__(cls, *args, **kwargs):
  52. """
  53. This is the function which interprets
  54. arguments given to Plot.__init__ and
  55. Plot.__setattr__. Returns an initialized
  56. instance of the appropriate child class.
  57. """
  58. newargs, newkwargs = PlotMode._extract_options(args, kwargs)
  59. mode_arg = newkwargs.get('mode', '')
  60. # Interpret the arguments
  61. d_vars, intervals = PlotMode._interpret_args(newargs)
  62. i_vars = PlotMode._find_i_vars(d_vars, intervals)
  63. i, d = max([len(i_vars), len(intervals)]), len(d_vars)
  64. # Find the appropriate mode
  65. subcls = PlotMode._get_mode(mode_arg, i, d)
  66. # Create the object
  67. o = object.__new__(subcls)
  68. # Do some setup for the mode instance
  69. o.d_vars = d_vars
  70. o._fill_i_vars(i_vars)
  71. o._fill_intervals(intervals)
  72. o.options = newkwargs
  73. return o
  74. @staticmethod
  75. def _get_mode(mode_arg, i_var_count, d_var_count):
  76. """
  77. Tries to return an appropriate mode class.
  78. Intended to be called only by __new__.
  79. mode_arg
  80. Can be a string or a class. If it is a
  81. PlotMode subclass, it is simply returned.
  82. If it is a string, it can an alias for
  83. a mode or an empty string. In the latter
  84. case, we try to find a default mode for
  85. the i_var_count and d_var_count.
  86. i_var_count
  87. The number of independent variables
  88. needed to evaluate the d_vars.
  89. d_var_count
  90. The number of dependent variables;
  91. usually the number of functions to
  92. be evaluated in plotting.
  93. For example, a Cartesian function y = f(x) has
  94. one i_var (x) and one d_var (y). A parametric
  95. form x,y,z = f(u,v), f(u,v), f(u,v) has two
  96. two i_vars (u,v) and three d_vars (x,y,z).
  97. """
  98. # if the mode_arg is simply a PlotMode class,
  99. # check that the mode supports the numbers
  100. # of independent and dependent vars, then
  101. # return it
  102. try:
  103. m = None
  104. if issubclass(mode_arg, PlotMode):
  105. m = mode_arg
  106. except TypeError:
  107. pass
  108. if m:
  109. if not m._was_initialized:
  110. raise ValueError(("To use unregistered plot mode %s "
  111. "you must first call %s._init_mode().")
  112. % (m.__name__, m.__name__))
  113. if d_var_count != m.d_var_count:
  114. raise ValueError(("%s can only plot functions "
  115. "with %i dependent variables.")
  116. % (m.__name__,
  117. m.d_var_count))
  118. if i_var_count > m.i_var_count:
  119. raise ValueError(("%s cannot plot functions "
  120. "with more than %i independent "
  121. "variables.")
  122. % (m.__name__,
  123. m.i_var_count))
  124. return m
  125. # If it is a string, there are two possibilities.
  126. if isinstance(mode_arg, str):
  127. i, d = i_var_count, d_var_count
  128. if i > PlotMode._i_var_max:
  129. raise ValueError(var_count_error(True, True))
  130. if d > PlotMode._d_var_max:
  131. raise ValueError(var_count_error(False, True))
  132. # If the string is '', try to find a suitable
  133. # default mode
  134. if not mode_arg:
  135. return PlotMode._get_default_mode(i, d)
  136. # Otherwise, interpret the string as a mode
  137. # alias (e.g. 'cartesian', 'parametric', etc)
  138. else:
  139. return PlotMode._get_aliased_mode(mode_arg, i, d)
  140. else:
  141. raise ValueError("PlotMode argument must be "
  142. "a class or a string")
  143. @staticmethod
  144. def _get_default_mode(i, d, i_vars=-1):
  145. if i_vars == -1:
  146. i_vars = i
  147. try:
  148. return PlotMode._mode_default_map[d][i]
  149. except KeyError:
  150. # Keep looking for modes in higher i var counts
  151. # which support the given d var count until we
  152. # reach the max i_var count.
  153. if i < PlotMode._i_var_max:
  154. return PlotMode._get_default_mode(i + 1, d, i_vars)
  155. else:
  156. raise ValueError(("Couldn't find a default mode "
  157. "for %i independent and %i "
  158. "dependent variables.") % (i_vars, d))
  159. @staticmethod
  160. def _get_aliased_mode(alias, i, d, i_vars=-1):
  161. if i_vars == -1:
  162. i_vars = i
  163. if alias not in PlotMode._mode_alias_list:
  164. raise ValueError(("Couldn't find a mode called"
  165. " %s. Known modes: %s.")
  166. % (alias, ", ".join(PlotMode._mode_alias_list)))
  167. try:
  168. return PlotMode._mode_map[d][i][alias]
  169. except TypeError:
  170. # Keep looking for modes in higher i var counts
  171. # which support the given d var count and alias
  172. # until we reach the max i_var count.
  173. if i < PlotMode._i_var_max:
  174. return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars)
  175. else:
  176. raise ValueError(("Couldn't find a %s mode "
  177. "for %i independent and %i "
  178. "dependent variables.")
  179. % (alias, i_vars, d))
  180. @classmethod
  181. def _register(cls):
  182. """
  183. Called once for each user-usable plot mode.
  184. For Cartesian2D, it is invoked after the
  185. class definition: Cartesian2D._register()
  186. """
  187. name = cls.__name__
  188. cls._init_mode()
  189. try:
  190. i, d = cls.i_var_count, cls.d_var_count
  191. # Add the mode to _mode_map under all
  192. # given aliases
  193. for a in cls.aliases:
  194. if a not in PlotMode._mode_alias_list:
  195. # Also track valid aliases, so
  196. # we can quickly know when given
  197. # an invalid one in _get_mode.
  198. PlotMode._mode_alias_list.append(a)
  199. PlotMode._mode_map[d][i][a] = cls
  200. if cls.is_default:
  201. # If this mode was marked as the
  202. # default for this d,i combination,
  203. # also set that.
  204. PlotMode._mode_default_map[d][i] = cls
  205. except Exception as e:
  206. raise RuntimeError(("Failed to register "
  207. "plot mode %s. Reason: %s")
  208. % (name, (str(e))))
  209. @classmethod
  210. def _init_mode(cls):
  211. """
  212. Initializes the plot mode based on
  213. the 'mode-specific parameters' above.
  214. Only intended to be called by
  215. PlotMode._register(). To use a mode without
  216. registering it, you can directly call
  217. ModeSubclass._init_mode().
  218. """
  219. def symbols_list(symbol_str):
  220. return [Symbol(s) for s in symbol_str]
  221. # Convert the vars strs into
  222. # lists of symbols.
  223. cls.i_vars = symbols_list(cls.i_vars)
  224. cls.d_vars = symbols_list(cls.d_vars)
  225. # Var count is used often, calculate
  226. # it once here
  227. cls.i_var_count = len(cls.i_vars)
  228. cls.d_var_count = len(cls.d_vars)
  229. if cls.i_var_count > PlotMode._i_var_max:
  230. raise ValueError(var_count_error(True, False))
  231. if cls.d_var_count > PlotMode._d_var_max:
  232. raise ValueError(var_count_error(False, False))
  233. # Try to use first alias as primary_alias
  234. if len(cls.aliases) > 0:
  235. cls.primary_alias = cls.aliases[0]
  236. else:
  237. cls.primary_alias = cls.__name__
  238. di = cls.intervals
  239. if len(di) != cls.i_var_count:
  240. raise ValueError("Plot mode must provide a "
  241. "default interval for each i_var.")
  242. for i in range(cls.i_var_count):
  243. # default intervals must be given [min,max,steps]
  244. # (no var, but they must be in the same order as i_vars)
  245. if len(di[i]) != 3:
  246. raise ValueError("length should be equal to 3")
  247. # Initialize an incomplete interval,
  248. # to later be filled with a var when
  249. # the mode is instantiated.
  250. di[i] = PlotInterval(None, *di[i])
  251. # To prevent people from using modes
  252. # without these required fields set up.
  253. cls._was_initialized = True
  254. _was_initialized = False
  255. ## Initializer Helper Methods
  256. @staticmethod
  257. def _find_i_vars(functions, intervals):
  258. i_vars = []
  259. # First, collect i_vars in the
  260. # order they are given in any
  261. # intervals.
  262. for i in intervals:
  263. if i.v is None:
  264. continue
  265. elif i.v in i_vars:
  266. raise ValueError(("Multiple intervals given "
  267. "for %s.") % (str(i.v)))
  268. i_vars.append(i.v)
  269. # Then, find any remaining
  270. # i_vars in given functions
  271. # (aka d_vars)
  272. for f in functions:
  273. for a in f.free_symbols:
  274. if a not in i_vars:
  275. i_vars.append(a)
  276. return i_vars
  277. def _fill_i_vars(self, i_vars):
  278. # copy default i_vars
  279. self.i_vars = [Symbol(str(i)) for i in self.i_vars]
  280. # replace with given i_vars
  281. for i in range(len(i_vars)):
  282. self.i_vars[i] = i_vars[i]
  283. def _fill_intervals(self, intervals):
  284. # copy default intervals
  285. self.intervals = [PlotInterval(i) for i in self.intervals]
  286. # track i_vars used so far
  287. v_used = []
  288. # fill copy of default
  289. # intervals with given info
  290. for i in range(len(intervals)):
  291. self.intervals[i].fill_from(intervals[i])
  292. if self.intervals[i].v is not None:
  293. v_used.append(self.intervals[i].v)
  294. # Find any orphan intervals and
  295. # assign them i_vars
  296. for i in range(len(self.intervals)):
  297. if self.intervals[i].v is None:
  298. u = [v for v in self.i_vars if v not in v_used]
  299. if len(u) == 0:
  300. raise ValueError("length should not be equal to 0")
  301. self.intervals[i].v = u[0]
  302. v_used.append(u[0])
  303. @staticmethod
  304. def _interpret_args(args):
  305. interval_wrong_order = "PlotInterval %s was given before any function(s)."
  306. interpret_error = "Could not interpret %s as a function or interval."
  307. functions, intervals = [], []
  308. if isinstance(args[0], GeometryEntity):
  309. for coords in list(args[0].arbitrary_point()):
  310. functions.append(coords)
  311. intervals.append(PlotInterval.try_parse(args[0].plot_interval()))
  312. else:
  313. for a in args:
  314. i = PlotInterval.try_parse(a)
  315. if i is not None:
  316. if len(functions) == 0:
  317. raise ValueError(interval_wrong_order % (str(i)))
  318. else:
  319. intervals.append(i)
  320. else:
  321. if is_sequence(a, include=str):
  322. raise ValueError(interpret_error % (str(a)))
  323. try:
  324. f = sympify(a)
  325. functions.append(f)
  326. except TypeError:
  327. raise ValueError(interpret_error % str(a))
  328. return functions, intervals
  329. @staticmethod
  330. def _extract_options(args, kwargs):
  331. newkwargs, newargs = {}, []
  332. for a in args:
  333. if isinstance(a, str):
  334. newkwargs = dict(newkwargs, **parse_option_string(a))
  335. else:
  336. newargs.append(a)
  337. newkwargs = dict(newkwargs, **kwargs)
  338. return newargs, newkwargs
  339. def var_count_error(is_independent, is_plotting):
  340. """
  341. Used to format an error message which differs
  342. slightly in 4 places.
  343. """
  344. if is_plotting:
  345. v = "Plotting"
  346. else:
  347. v = "Registering plot modes"
  348. if is_independent:
  349. n, s = PlotMode._i_var_max, "independent"
  350. else:
  351. n, s = PlotMode._d_var_max, "dependent"
  352. return ("%s with more than %i %s variables "
  353. "is not supported.") % (v, n, s)