|
- from .plot_interval import PlotInterval
- from .plot_object import PlotObject
- from .util import parse_option_string
- from sympy.core.symbol import Symbol
- from sympy.core.sympify import sympify
- from sympy.geometry.entity import GeometryEntity
- from sympy.utilities.iterables import is_sequence
class PlotMode(PlotObject):
    """
    Grandparent class for plotting
    modes. Serves as interface for
    registration, lookup, and init
    of modes.

    To create a new plot mode,
    inherit from PlotModeBase
    or one of its children, such
    as PlotSurface or PlotCurve.
    """

    ## Class-level attributes
    ## used to register and lookup
    ## plot modes. See PlotModeBase
    ## for descriptions and usage.

    # Strings of single-character variable names; _init_mode() converts
    # them into lists of Symbols.
    i_vars, d_vars = '', ''
    # Default intervals, one [min, max, steps] triple per i_var;
    # _init_mode() converts them into PlotInterval objects.
    intervals = []
    # Names under which the mode is registered; the first one becomes
    # primary_alias.
    aliases = []
    # When True, _register() also makes this the default mode for its
    # (d_var_count, i_var_count) combination.
    is_default = False

    ## Draw is the only method here which
    ## is meant to be overridden in child
    ## classes, and PlotModeBase provides
    ## a base implementation.
    def draw(self):
        raise NotImplementedError()

    ## Everything else in this file has to
    ## do with registration and retrieval
    ## of plot modes. This is where I've
    ## hidden much of the ugliness of automatic
    ## plot mode divination...

    ## Plot mode registry data structures
    _mode_alias_list = []
    _mode_map = {
        1: {1: {}, 2: {}},
        2: {1: {}, 2: {}},
        3: {1: {}, 2: {}},
    }  # [d][i][alias_str]: class
    _mode_default_map = {
        1: {},
        2: {},
        3: {},
    }  # [d][i]: class

    # Largest supported numbers of independent / dependent variables.
    _i_var_max, _d_var_max = 2, 3

    def __new__(cls, *args, **kwargs):
        """
        This is the function which interprets
        arguments given to Plot.__init__ and
        Plot.__setattr__. Returns an initialized
        instance of the appropriate child class.

        String positional arguments are parsed as option strings.
        A ``mode`` keyword (class or alias string) selects the mode
        explicitly; otherwise a default mode is looked up from the
        number of independent and dependent variables.
        """
        newargs, newkwargs = PlotMode._extract_options(args, kwargs)
        mode_arg = newkwargs.get('mode', '')

        # Interpret the arguments
        d_vars, intervals = PlotMode._interpret_args(newargs)
        i_vars = PlotMode._find_i_vars(d_vars, intervals)
        # Intervals may be given without a variable, so the i-var count
        # is the larger of the distinct i_vars found and the intervals given.
        i, d = max(len(i_vars), len(intervals)), len(d_vars)

        # Find the appropriate mode
        subcls = PlotMode._get_mode(mode_arg, i, d)

        # Create the object (bypassing __new__ dispatch on the subclass)
        o = object.__new__(subcls)

        # Do some setup for the mode instance
        o.d_vars = d_vars
        o._fill_i_vars(i_vars)
        o._fill_intervals(intervals)
        o.options = newkwargs

        return o

    @staticmethod
    def _get_mode(mode_arg, i_var_count, d_var_count):
        """
        Tries to return an appropriate mode class.
        Intended to be called only by __new__.

        mode_arg
            Can be a string or a class. If it is a
            PlotMode subclass, it is simply returned.
            If it is a string, it can be an alias for
            a mode or an empty string. In the latter
            case, we try to find a default mode for
            the i_var_count and d_var_count.

        i_var_count
            The number of independent variables
            needed to evaluate the d_vars.

        d_var_count
            The number of dependent variables;
            usually the number of functions to
            be evaluated in plotting.

        For example, a Cartesian function y = f(x) has
        one i_var (x) and one d_var (y). A parametric
        form x,y,z = f(u,v), f(u,v), f(u,v) has two
        i_vars (u,v) and three d_vars (x,y,z).

        Raises ValueError if no suitable mode can be found or the
        given mode class cannot handle the variable counts.
        """
        # if the mode_arg is simply a PlotMode class,
        # check that the mode supports the numbers
        # of independent and dependent vars, then
        # return it
        m = None
        try:
            if issubclass(mode_arg, PlotMode):
                m = mode_arg
        except TypeError:
            # issubclass() raises TypeError when mode_arg is not a class
            # (e.g. a string); that case is handled below.
            pass
        if m:
            if not m._was_initialized:
                raise ValueError(("To use unregistered plot mode %s "
                                  "you must first call %s._init_mode().")
                                 % (m.__name__, m.__name__))
            if d_var_count != m.d_var_count:
                raise ValueError(("%s can only plot functions "
                                  "with %i dependent variables.")
                                 % (m.__name__,
                                     m.d_var_count))
            if i_var_count > m.i_var_count:
                raise ValueError(("%s cannot plot functions "
                                  "with more than %i independent "
                                  "variables.")
                                 % (m.__name__,
                                     m.i_var_count))
            return m

        # If it is a string, there are two possibilities.
        if isinstance(mode_arg, str):
            i, d = i_var_count, d_var_count
            if i > PlotMode._i_var_max:
                raise ValueError(var_count_error(True, True))
            if d > PlotMode._d_var_max:
                raise ValueError(var_count_error(False, True))

            # If the string is '', try to find a suitable
            # default mode
            if not mode_arg:
                return PlotMode._get_default_mode(i, d)
            # Otherwise, interpret the string as a mode
            # alias (e.g. 'cartesian', 'parametric', etc)
            return PlotMode._get_aliased_mode(mode_arg, i, d)

        raise ValueError("PlotMode argument must be "
                         "a class or a string")

    @staticmethod
    def _get_default_mode(i, d, i_vars=-1):
        """
        Return the default mode class registered for ``d`` dependent and
        ``i`` independent variables.

        If none is registered for ``i``, the search continues with
        ``i + 1`` up to ``_i_var_max``. ``i_vars`` keeps the originally
        requested count so the error message reports it; ``-1`` means
        "use ``i``". Raises ValueError if nothing is found.
        """
        if i_vars == -1:
            i_vars = i
        try:
            return PlotMode._mode_default_map[d][i]
        except KeyError:
            # Keep looking for modes in higher i var counts
            # which support the given d var count until we
            # reach the max i_var count.
            if i < PlotMode._i_var_max:
                return PlotMode._get_default_mode(i + 1, d, i_vars)
            raise ValueError(("Couldn't find a default mode "
                              "for %i independent and %i "
                              "dependent variables.") % (i_vars, d))

    @staticmethod
    def _get_aliased_mode(alias, i, d, i_vars=-1):
        """
        Return the mode class registered under ``alias`` for ``d``
        dependent and ``i`` independent variables.

        Unknown aliases raise ValueError right away. A known alias with
        no entry for ``i`` is retried with ``i + 1`` up to
        ``_i_var_max``. ``i_vars`` keeps the originally requested count
        for the error message; ``-1`` means "use ``i``".
        """
        if i_vars == -1:
            i_vars = i
        if alias not in PlotMode._mode_alias_list:
            raise ValueError(("Couldn't find a mode called"
                              " %s. Known modes: %s.")
                             % (alias, ", ".join(PlotMode._mode_alias_list)))
        try:
            return PlotMode._mode_map[d][i][alias]
        except KeyError:
            # A missing [d], [i] or [alias] entry is a dict lookup
            # failure, i.e. KeyError (the original caught TypeError, so
            # this fallback never ran). Keep looking for modes in higher
            # i var counts which support the given d var count and alias
            # until we reach the max i_var count.
            if i < PlotMode._i_var_max:
                return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars)
            raise ValueError(("Couldn't find a %s mode "
                              "for %i independent and %i "
                              "dependent variables.")
                             % (alias, i_vars, d))

    @classmethod
    def _register(cls):
        """
        Called once for each user-usable plot mode.
        For Cartesian2D, it is invoked after the
        class definition: Cartesian2D._register()

        Initializes the mode, then records it in _mode_map under every
        alias and, when is_default is set, in _mode_default_map.
        Registry failures are re-raised as RuntimeError, chained to the
        original exception.
        """
        name = cls.__name__
        cls._init_mode()

        try:
            i, d = cls.i_var_count, cls.d_var_count
            # Add the mode to _mode_map under all
            # given aliases
            for a in cls.aliases:
                if a not in PlotMode._mode_alias_list:
                    # Also track valid aliases, so
                    # we can quickly know when given
                    # an invalid one in _get_mode.
                    PlotMode._mode_alias_list.append(a)
                PlotMode._mode_map[d][i][a] = cls
            if cls.is_default:
                # If this mode was marked as the
                # default for this d,i combination,
                # also set that.
                PlotMode._mode_default_map[d][i] = cls

        except Exception as e:
            raise RuntimeError(("Failed to register "
                                "plot mode %s. Reason: %s")
                               % (name, (str(e)))) from e

    @classmethod
    def _init_mode(cls):
        """
        Initializes the plot mode based on
        the 'mode-specific parameters' above.
        Only intended to be called by
        PlotMode._register(). To use a mode without
        registering it, you can directly call
        ModeSubclass._init_mode().

        Converts i_vars/d_vars strings into Symbol lists, sets
        i_var_count, d_var_count and primary_alias, validates the default
        intervals and replaces them with incomplete PlotInterval objects.
        Raises ValueError on invalid configuration.
        """
        def symbols_list(symbol_str):
            return [Symbol(s) for s in symbol_str]

        # Convert the vars strs into
        # lists of symbols.
        cls.i_vars = symbols_list(cls.i_vars)
        cls.d_vars = symbols_list(cls.d_vars)

        # Var count is used often, calculate
        # it once here
        cls.i_var_count = len(cls.i_vars)
        cls.d_var_count = len(cls.d_vars)

        if cls.i_var_count > PlotMode._i_var_max:
            raise ValueError(var_count_error(True, False))
        if cls.d_var_count > PlotMode._d_var_max:
            raise ValueError(var_count_error(False, False))

        # Try to use first alias as primary_alias
        cls.primary_alias = cls.aliases[0] if cls.aliases else cls.__name__

        default_intervals = cls.intervals
        if len(default_intervals) != cls.i_var_count:
            raise ValueError("Plot mode must provide a "
                             "default interval for each i_var.")
        # default intervals must be given [min,max,steps]
        # (no var, but they must be in the same order as i_vars).
        # Validate all of them before converting any, so a bad entry
        # cannot leave the list half-converted.
        for interval in default_intervals:
            if len(interval) != 3:
                raise ValueError("length should be equal to 3")
        # Initialize incomplete intervals, to later be filled with a var
        # when the mode is instantiated. Build a new list rather than
        # mutating in place: cls.intervals may be a list shared with a
        # parent class (e.g. PlotMode.intervals).
        cls.intervals = [PlotInterval(None, *interval)
                         for interval in default_intervals]

        # To prevent people from using modes
        # without these required fields set up.
        cls._was_initialized = True

    # Set to True per class by _init_mode(); _get_mode() refuses
    # uninitialized mode classes.
    _was_initialized = False

    ## Initializer Helper Methods

    @staticmethod
    def _find_i_vars(functions, intervals):
        """
        Return the list of independent variables for a plot.

        Variables named by ``intervals`` come first, in the order given.
        Raises ValueError if two intervals name the same variable. Any
        remaining free symbols of ``functions`` follow, sorted by name so
        that the order is the same on every run.
        """
        i_vars = []

        # First, collect i_vars in the
        # order they are given in any
        # intervals.
        for interval in intervals:
            if interval.v is None:
                continue
            if interval.v in i_vars:
                raise ValueError(("Multiple intervals given "
                                  "for %s.") % (str(interval.v)))
            i_vars.append(interval.v)

        # Then, find any remaining
        # i_vars in given functions
        # (aka d_vars). free_symbols is a set whose iteration order
        # depends on (randomized) string hashing, so sort it.
        for f in functions:
            for a in sorted(f.free_symbols, key=str):
                if a not in i_vars:
                    i_vars.append(a)

        return i_vars

    def _fill_i_vars(self, i_vars):
        """
        Give this instance its own i_vars list: fresh Symbols named like
        the class defaults, with the leading entries replaced by the
        given ``i_vars``.
        """
        # copy default i_vars
        self.i_vars = [Symbol(str(i)) for i in self.i_vars]
        # replace with given i_vars
        for index, given in enumerate(i_vars):
            self.i_vars[index] = given

    def _fill_intervals(self, intervals):
        """
        Give this instance its own interval list, built from the class
        defaults and filled in with the user-supplied ``intervals``.

        Intervals still lacking a variable are assigned the first i_var
        not already used. Raises ValueError if none is left.
        """
        # copy default intervals (PlotInterval(i) presumably copies i —
        # TODO confirm against PlotInterval)
        self.intervals = [PlotInterval(i) for i in self.intervals]
        # track i_vars used so far
        v_used = []
        # fill copy of default
        # intervals with given info
        for index, given in enumerate(intervals):
            self.intervals[index].fill_from(given)
            if self.intervals[index].v is not None:
                v_used.append(self.intervals[index].v)
        # Find any orphan intervals and
        # assign them i_vars
        for interval in self.intervals:
            if interval.v is None:
                unused = [v for v in self.i_vars if v not in v_used]
                if not unused:
                    raise ValueError("length should not be equal to 0")
                interval.v = unused[0]
                v_used.append(unused[0])

    @staticmethod
    def _interpret_args(args):
        """
        Split positional plot arguments into ``(functions, intervals)``.

        A GeometryEntity as first argument supplies the functions (from
        arbitrary_point()) and one interval (from plot_interval()).
        Otherwise each argument is tried as an interval first, then
        sympified as a function. Raises ValueError for an interval given
        before any function, or for an argument that cannot be interpreted.
        """
        interval_wrong_order = "PlotInterval %s was given before any function(s)."
        interpret_error = "Could not interpret %s as a function or interval."

        functions, intervals = [], []
        # Guard against empty args: indexing args[0] would raise
        # IndexError. With no args, mode lookup reports a ValueError instead.
        if args and isinstance(args[0], GeometryEntity):
            functions.extend(args[0].arbitrary_point())
            intervals.append(PlotInterval.try_parse(args[0].plot_interval()))
        else:
            for a in args:
                i = PlotInterval.try_parse(a)
                if i is not None:
                    if not functions:
                        raise ValueError(interval_wrong_order % (str(i)))
                    intervals.append(i)
                else:
                    if is_sequence(a, include=str):
                        raise ValueError(interpret_error % (str(a)))
                    try:
                        f = sympify(a)
                    except TypeError as err:
                        raise ValueError(interpret_error % str(a)) from err
                    functions.append(f)
        return functions, intervals

    @staticmethod
    def _extract_options(args, kwargs):
        """
        Separate option strings from positional arguments.

        Every ``str`` positional argument is parsed with
        parse_option_string and merged into the options in order. Explicit
        ``kwargs`` are applied last and override them. Returns
        ``(newargs, newkwargs)``.
        """
        newkwargs, newargs = {}, []
        for a in args:
            if isinstance(a, str):
                newkwargs.update(parse_option_string(a))
            else:
                newargs.append(a)
        newkwargs.update(kwargs)
        return newargs, newkwargs
def var_count_error(is_independent, is_plotting):
    """
    Build the message for exceeding PlotMode's variable-count limits.

    is_independent
        True to report the independent-variable limit
        (PlotMode._i_var_max), False for the dependent-variable limit
        (PlotMode._d_var_max).
    is_plotting
        True when the limit was hit while plotting, False when it was
        hit while registering a plot mode.
    """
    action = "Plotting" if is_plotting else "Registering plot modes"
    if is_independent:
        limit, kind = PlotMode._i_var_max, "independent"
    else:
        limit, kind = PlotMode._d_var_max, "dependent"
    template = "%s with more than %i %s variables is not supported."
    return template % (action, limit, kind)
|