circuitplot.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
"""Matplotlib based plotting of quantum circuits.

Todo:

* Optimize printing of large circuits.
* Get this to work with single gates.
* Do a better job checking the form of circuits to make sure it is a Mul of
  Gates.
* Get multi-target gates plotting.
* Get initial and final states to plot.
* Get measurements to plot. Might need to rethink measurement as a gate
  issue.
* Get scale and figsize to be handled in a better way.
* Write some tests/examples!
"""
  14. from __future__ import annotations
  15. from sympy.core.mul import Mul
  16. from sympy.external import import_module
  17. from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS
  18. __all__ = [
  19. 'CircuitPlot',
  20. 'circuit_plot',
  21. 'labeller',
  22. 'Mz',
  23. 'Mx',
  24. 'CreateOneQubitGate',
  25. 'CreateCGate',
  26. ]
  27. np = import_module('numpy')
  28. matplotlib = import_module(
  29. 'matplotlib', import_kwargs={'fromlist': ['pyplot']},
  30. catch=(RuntimeError,)) # This is raised in environments that have no display.
  31. if np and matplotlib:
  32. pyplot = matplotlib.pyplot
  33. Line2D = matplotlib.lines.Line2D
  34. Circle = matplotlib.patches.Circle
  35. #from matplotlib import rc
  36. #rc('text',usetex=True)
  37. class CircuitPlot:
  38. """A class for managing a circuit plot."""
  39. scale = 1.0
  40. fontsize = 20.0
  41. linewidth = 1.0
  42. control_radius = 0.05
  43. not_radius = 0.15
  44. swap_delta = 0.05
  45. labels: list[str] = []
  46. inits: dict[str, str] = {}
  47. label_buffer = 0.5
  48. def __init__(self, c, nqubits, **kwargs):
  49. if not np or not matplotlib:
  50. raise ImportError('numpy or matplotlib not available.')
  51. self.circuit = c
  52. self.ngates = len(self.circuit.args)
  53. self.nqubits = nqubits
  54. self.update(kwargs)
  55. self._create_grid()
  56. self._create_figure()
  57. self._plot_wires()
  58. self._plot_gates()
  59. self._finish()
  60. def update(self, kwargs):
  61. """Load the kwargs into the instance dict."""
  62. self.__dict__.update(kwargs)
  63. def _create_grid(self):
  64. """Create the grid of wires."""
  65. scale = self.scale
  66. wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float)
  67. gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float)
  68. self._wire_grid = wire_grid
  69. self._gate_grid = gate_grid
  70. def _create_figure(self):
  71. """Create the main matplotlib figure."""
  72. self._figure = pyplot.figure(
  73. figsize=(self.ngates*self.scale, self.nqubits*self.scale),
  74. facecolor='w',
  75. edgecolor='w'
  76. )
  77. ax = self._figure.add_subplot(
  78. 1, 1, 1,
  79. frameon=True
  80. )
  81. ax.set_axis_off()
  82. offset = 0.5*self.scale
  83. ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
  84. ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
  85. ax.set_aspect('equal')
  86. self._axes = ax
  87. def _plot_wires(self):
  88. """Plot the wires of the circuit diagram."""
  89. xstart = self._gate_grid[0]
  90. xstop = self._gate_grid[-1]
  91. xdata = (xstart - self.scale, xstop + self.scale)
  92. for i in range(self.nqubits):
  93. ydata = (self._wire_grid[i], self._wire_grid[i])
  94. line = Line2D(
  95. xdata, ydata,
  96. color='k',
  97. lw=self.linewidth
  98. )
  99. self._axes.add_line(line)
  100. if self.labels:
  101. init_label_buffer = 0
  102. if self.inits.get(self.labels[i]): init_label_buffer = 0.25
  103. self._axes.text(
  104. xdata[0]-self.label_buffer-init_label_buffer,ydata[0],
  105. render_label(self.labels[i],self.inits),
  106. size=self.fontsize,
  107. color='k',ha='center',va='center')
  108. self._plot_measured_wires()
  109. def _plot_measured_wires(self):
  110. ismeasured = self._measurements()
  111. xstop = self._gate_grid[-1]
  112. dy = 0.04 # amount to shift wires when doubled
  113. # Plot doubled wires after they are measured
  114. for im in ismeasured:
  115. xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale)
  116. ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy)
  117. line = Line2D(
  118. xdata, ydata,
  119. color='k',
  120. lw=self.linewidth
  121. )
  122. self._axes.add_line(line)
  123. # Also double any controlled lines off these wires
  124. for i,g in enumerate(self._gates()):
  125. if isinstance(g, (CGate, CGateS)):
  126. wires = g.controls + g.targets
  127. for wire in wires:
  128. if wire in ismeasured and \
  129. self._gate_grid[i] > self._gate_grid[ismeasured[wire]]:
  130. ydata = min(wires), max(wires)
  131. xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy
  132. line = Line2D(
  133. xdata, ydata,
  134. color='k',
  135. lw=self.linewidth
  136. )
  137. self._axes.add_line(line)
  138. def _gates(self):
  139. """Create a list of all gates in the circuit plot."""
  140. gates = []
  141. if isinstance(self.circuit, Mul):
  142. for g in reversed(self.circuit.args):
  143. if isinstance(g, Gate):
  144. gates.append(g)
  145. elif isinstance(self.circuit, Gate):
  146. gates.append(self.circuit)
  147. return gates
  148. def _plot_gates(self):
  149. """Iterate through the gates and plot each of them."""
  150. for i, gate in enumerate(self._gates()):
  151. gate.plot_gate(self, i)
  152. def _measurements(self):
  153. """Return a dict ``{i:j}`` where i is the index of the wire that has
  154. been measured, and j is the gate where the wire is measured.
  155. """
  156. ismeasured = {}
  157. for i,g in enumerate(self._gates()):
  158. if getattr(g,'measurement',False):
  159. for target in g.targets:
  160. if target in ismeasured:
  161. if ismeasured[target] > i:
  162. ismeasured[target] = i
  163. else:
  164. ismeasured[target] = i
  165. return ismeasured
  166. def _finish(self):
  167. # Disable clipping to make panning work well for large circuits.
  168. for o in self._figure.findobj():
  169. o.set_clip_on(False)
  170. def one_qubit_box(self, t, gate_idx, wire_idx):
  171. """Draw a box for a single qubit gate."""
  172. x = self._gate_grid[gate_idx]
  173. y = self._wire_grid[wire_idx]
  174. self._axes.text(
  175. x, y, t,
  176. color='k',
  177. ha='center',
  178. va='center',
  179. bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth},
  180. size=self.fontsize
  181. )
  182. def two_qubit_box(self, t, gate_idx, wire_idx):
  183. """Draw a box for a two qubit gate. Does not work yet.
  184. """
  185. # x = self._gate_grid[gate_idx]
  186. # y = self._wire_grid[wire_idx]+0.5
  187. print(self._gate_grid)
  188. print(self._wire_grid)
  189. # unused:
  190. # obj = self._axes.text(
  191. # x, y, t,
  192. # color='k',
  193. # ha='center',
  194. # va='center',
  195. # bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
  196. # size=self.fontsize
  197. # )
  198. def control_line(self, gate_idx, min_wire, max_wire):
  199. """Draw a vertical control line."""
  200. xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
  201. ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
  202. line = Line2D(
  203. xdata, ydata,
  204. color='k',
  205. lw=self.linewidth
  206. )
  207. self._axes.add_line(line)
  208. def control_point(self, gate_idx, wire_idx):
  209. """Draw a control point."""
  210. x = self._gate_grid[gate_idx]
  211. y = self._wire_grid[wire_idx]
  212. radius = self.control_radius
  213. c = Circle(
  214. (x, y),
  215. radius*self.scale,
  216. ec='k',
  217. fc='k',
  218. fill=True,
  219. lw=self.linewidth
  220. )
  221. self._axes.add_patch(c)
  222. def not_point(self, gate_idx, wire_idx):
  223. """Draw a NOT gates as the circle with plus in the middle."""
  224. x = self._gate_grid[gate_idx]
  225. y = self._wire_grid[wire_idx]
  226. radius = self.not_radius
  227. c = Circle(
  228. (x, y),
  229. radius,
  230. ec='k',
  231. fc='w',
  232. fill=False,
  233. lw=self.linewidth
  234. )
  235. self._axes.add_patch(c)
  236. l = Line2D(
  237. (x, x), (y - radius, y + radius),
  238. color='k',
  239. lw=self.linewidth
  240. )
  241. self._axes.add_line(l)
  242. def swap_point(self, gate_idx, wire_idx):
  243. """Draw a swap point as a cross."""
  244. x = self._gate_grid[gate_idx]
  245. y = self._wire_grid[wire_idx]
  246. d = self.swap_delta
  247. l1 = Line2D(
  248. (x - d, x + d),
  249. (y - d, y + d),
  250. color='k',
  251. lw=self.linewidth
  252. )
  253. l2 = Line2D(
  254. (x - d, x + d),
  255. (y + d, y - d),
  256. color='k',
  257. lw=self.linewidth
  258. )
  259. self._axes.add_line(l1)
  260. self._axes.add_line(l2)
  261. def circuit_plot(c, nqubits, **kwargs):
  262. """Draw the circuit diagram for the circuit with nqubits.
  263. Parameters
  264. ==========
  265. c : circuit
  266. The circuit to plot. Should be a product of Gate instances.
  267. nqubits : int
  268. The number of qubits to include in the circuit. Must be at least
  269. as big as the largest ``min_qubits`` of the gates.
  270. """
  271. return CircuitPlot(c, nqubits, **kwargs)
  272. def render_label(label, inits={}):
  273. """Slightly more flexible way to render labels.
  274. >>> from sympy.physics.quantum.circuitplot import render_label
  275. >>> render_label('q0')
  276. '$\\\\left|q0\\\\right\\\\rangle$'
  277. >>> render_label('q0', {'q0':'0'})
  278. '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
  279. """
  280. init = inits.get(label)
  281. if init:
  282. return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
  283. return r'$\left|%s\right\rangle$' % label
  284. def labeller(n, symbol='q'):
  285. """Autogenerate labels for wires of quantum circuits.
  286. Parameters
  287. ==========
  288. n : int
  289. number of qubits in the circuit.
  290. symbol : string
  291. A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc.
  292. >>> from sympy.physics.quantum.circuitplot import labeller
  293. >>> labeller(2)
  294. ['q_1', 'q_0']
  295. >>> labeller(3,'j')
  296. ['j_2', 'j_1', 'j_0']
  297. """
  298. return ['%s_%d' % (symbol,n-i-1) for i in range(n)]
  299. class Mz(OneQubitGate):
  300. """Mock-up of a z measurement gate.
  301. This is in circuitplot rather than gate.py because it's not a real
  302. gate, it just draws one.
  303. """
  304. measurement = True
  305. gate_name='Mz'
  306. gate_name_latex='M_z'
  307. class Mx(OneQubitGate):
  308. """Mock-up of an x measurement gate.
  309. This is in circuitplot rather than gate.py because it's not a real
  310. gate, it just draws one.
  311. """
  312. measurement = True
  313. gate_name='Mx'
  314. gate_name_latex='M_x'
  315. class CreateOneQubitGate(type):
  316. def __new__(mcl, name, latexname=None):
  317. if not latexname:
  318. latexname = name
  319. return type(name + "Gate", (OneQubitGate,),
  320. {'gate_name': name, 'gate_name_latex': latexname})
  321. def CreateCGate(name, latexname=None):
  322. """Use a lexical closure to make a controlled gate.
  323. """
  324. if not latexname:
  325. latexname = name
  326. onequbitgate = CreateOneQubitGate(name, latexname)
  327. def ControlledGate(ctrls,target):
  328. return CGate(tuple(ctrls),onequbitgate(target))
  329. return ControlledGate