functions.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. from sympy.utilities import dict_merge
  2. from sympy.utilities.iterables import iterable
  3. from sympy.physics.vector import (Dyadic, Vector, ReferenceFrame,
  4. Point, dynamicsymbols)
  5. from sympy.physics.vector.printing import (vprint, vsprint, vpprint, vlatex,
  6. init_vprinting)
  7. from sympy.physics.mechanics.particle import Particle
  8. from sympy.physics.mechanics.rigidbody import RigidBody
  9. from sympy.simplify.simplify import simplify
  10. from sympy.core.backend import (Matrix, sympify, Mul, Derivative, sin, cos,
  11. tan, AppliedUndef, S)
  12. __all__ = ['inertia',
  13. 'inertia_of_point_mass',
  14. 'linear_momentum',
  15. 'angular_momentum',
  16. 'kinetic_energy',
  17. 'potential_energy',
  18. 'Lagrangian',
  19. 'mechanics_printing',
  20. 'mprint',
  21. 'msprint',
  22. 'mpprint',
  23. 'mlatex',
  24. 'msubs',
  25. 'find_dynamicsymbols']
  26. # These are functions that we've moved and renamed during extracting the
  27. # basic vector calculus code from the mechanics packages.
  28. mprint = vprint
  29. msprint = vsprint
  30. mpprint = vpprint
  31. mlatex = vlatex
  32. def mechanics_printing(**kwargs):
  33. """
  34. Initializes time derivative printing for all SymPy objects in
  35. mechanics module.
  36. """
  37. init_vprinting(**kwargs)
  38. mechanics_printing.__doc__ = init_vprinting.__doc__
  39. def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0):
  40. """Simple way to create inertia Dyadic object.
  41. Explanation
  42. ===========
  43. If you do not know what a Dyadic is, just treat this like the inertia
  44. tensor. Then, do the easy thing and define it in a body-fixed frame.
  45. Parameters
  46. ==========
  47. frame : ReferenceFrame
  48. The frame the inertia is defined in
  49. ixx : Sympifyable
  50. the xx element in the inertia dyadic
  51. iyy : Sympifyable
  52. the yy element in the inertia dyadic
  53. izz : Sympifyable
  54. the zz element in the inertia dyadic
  55. ixy : Sympifyable
  56. the xy element in the inertia dyadic
  57. iyz : Sympifyable
  58. the yz element in the inertia dyadic
  59. izx : Sympifyable
  60. the zx element in the inertia dyadic
  61. Examples
  62. ========
  63. >>> from sympy.physics.mechanics import ReferenceFrame, inertia
  64. >>> N = ReferenceFrame('N')
  65. >>> inertia(N, 1, 2, 3)
  66. (N.x|N.x) + 2*(N.y|N.y) + 3*(N.z|N.z)
  67. """
  68. if not isinstance(frame, ReferenceFrame):
  69. raise TypeError('Need to define the inertia in a frame')
  70. ixx = sympify(ixx)
  71. ixy = sympify(ixy)
  72. iyy = sympify(iyy)
  73. iyz = sympify(iyz)
  74. izx = sympify(izx)
  75. izz = sympify(izz)
  76. ol = ixx * (frame.x | frame.x)
  77. ol += ixy * (frame.x | frame.y)
  78. ol += izx * (frame.x | frame.z)
  79. ol += ixy * (frame.y | frame.x)
  80. ol += iyy * (frame.y | frame.y)
  81. ol += iyz * (frame.y | frame.z)
  82. ol += izx * (frame.z | frame.x)
  83. ol += iyz * (frame.z | frame.y)
  84. ol += izz * (frame.z | frame.z)
  85. return ol
  86. def inertia_of_point_mass(mass, pos_vec, frame):
  87. """Inertia dyadic of a point mass relative to point O.
  88. Parameters
  89. ==========
  90. mass : Sympifyable
  91. Mass of the point mass
  92. pos_vec : Vector
  93. Position from point O to point mass
  94. frame : ReferenceFrame
  95. Reference frame to express the dyadic in
  96. Examples
  97. ========
  98. >>> from sympy import symbols
  99. >>> from sympy.physics.mechanics import ReferenceFrame, inertia_of_point_mass
  100. >>> N = ReferenceFrame('N')
  101. >>> r, m = symbols('r m')
  102. >>> px = r * N.x
  103. >>> inertia_of_point_mass(m, px, N)
  104. m*r**2*(N.y|N.y) + m*r**2*(N.z|N.z)
  105. """
  106. return mass * (((frame.x | frame.x) + (frame.y | frame.y) +
  107. (frame.z | frame.z)) * (pos_vec & pos_vec) -
  108. (pos_vec | pos_vec))
  109. def linear_momentum(frame, *body):
  110. """Linear momentum of the system.
  111. Explanation
  112. ===========
  113. This function returns the linear momentum of a system of Particle's and/or
  114. RigidBody's. The linear momentum of a system is equal to the vector sum of
  115. the linear momentum of its constituents. Consider a system, S, comprised of
  116. a rigid body, A, and a particle, P. The linear momentum of the system, L,
  117. is equal to the vector sum of the linear momentum of the particle, L1, and
  118. the linear momentum of the rigid body, L2, i.e.
  119. L = L1 + L2
  120. Parameters
  121. ==========
  122. frame : ReferenceFrame
  123. The frame in which linear momentum is desired.
  124. body1, body2, body3... : Particle and/or RigidBody
  125. The body (or bodies) whose linear momentum is required.
  126. Examples
  127. ========
  128. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  129. >>> from sympy.physics.mechanics import RigidBody, outer, linear_momentum
  130. >>> N = ReferenceFrame('N')
  131. >>> P = Point('P')
  132. >>> P.set_vel(N, 10 * N.x)
  133. >>> Pa = Particle('Pa', P, 1)
  134. >>> Ac = Point('Ac')
  135. >>> Ac.set_vel(N, 25 * N.y)
  136. >>> I = outer(N.x, N.x)
  137. >>> A = RigidBody('A', Ac, N, 20, (I, Ac))
  138. >>> linear_momentum(N, A, Pa)
  139. 10*N.x + 500*N.y
  140. """
  141. if not isinstance(frame, ReferenceFrame):
  142. raise TypeError('Please specify a valid ReferenceFrame')
  143. else:
  144. linear_momentum_sys = Vector(0)
  145. for e in body:
  146. if isinstance(e, (RigidBody, Particle)):
  147. linear_momentum_sys += e.linear_momentum(frame)
  148. else:
  149. raise TypeError('*body must have only Particle or RigidBody')
  150. return linear_momentum_sys
  151. def angular_momentum(point, frame, *body):
  152. """Angular momentum of a system.
  153. Explanation
  154. ===========
  155. This function returns the angular momentum of a system of Particle's and/or
  156. RigidBody's. The angular momentum of such a system is equal to the vector
  157. sum of the angular momentum of its constituents. Consider a system, S,
  158. comprised of a rigid body, A, and a particle, P. The angular momentum of
  159. the system, H, is equal to the vector sum of the angular momentum of the
  160. particle, H1, and the angular momentum of the rigid body, H2, i.e.
  161. H = H1 + H2
  162. Parameters
  163. ==========
  164. point : Point
  165. The point about which angular momentum of the system is desired.
  166. frame : ReferenceFrame
  167. The frame in which angular momentum is desired.
  168. body1, body2, body3... : Particle and/or RigidBody
  169. The body (or bodies) whose angular momentum is required.
  170. Examples
  171. ========
  172. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  173. >>> from sympy.physics.mechanics import RigidBody, outer, angular_momentum
  174. >>> N = ReferenceFrame('N')
  175. >>> O = Point('O')
  176. >>> O.set_vel(N, 0 * N.x)
  177. >>> P = O.locatenew('P', 1 * N.x)
  178. >>> P.set_vel(N, 10 * N.x)
  179. >>> Pa = Particle('Pa', P, 1)
  180. >>> Ac = O.locatenew('Ac', 2 * N.y)
  181. >>> Ac.set_vel(N, 5 * N.y)
  182. >>> a = ReferenceFrame('a')
  183. >>> a.set_ang_vel(N, 10 * N.z)
  184. >>> I = outer(N.z, N.z)
  185. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  186. >>> angular_momentum(O, N, Pa, A)
  187. 10*N.z
  188. """
  189. if not isinstance(frame, ReferenceFrame):
  190. raise TypeError('Please enter a valid ReferenceFrame')
  191. if not isinstance(point, Point):
  192. raise TypeError('Please specify a valid Point')
  193. else:
  194. angular_momentum_sys = Vector(0)
  195. for e in body:
  196. if isinstance(e, (RigidBody, Particle)):
  197. angular_momentum_sys += e.angular_momentum(point, frame)
  198. else:
  199. raise TypeError('*body must have only Particle or RigidBody')
  200. return angular_momentum_sys
  201. def kinetic_energy(frame, *body):
  202. """Kinetic energy of a multibody system.
  203. Explanation
  204. ===========
  205. This function returns the kinetic energy of a system of Particle's and/or
  206. RigidBody's. The kinetic energy of such a system is equal to the sum of
  207. the kinetic energies of its constituents. Consider a system, S, comprising
  208. a rigid body, A, and a particle, P. The kinetic energy of the system, T,
  209. is equal to the vector sum of the kinetic energy of the particle, T1, and
  210. the kinetic energy of the rigid body, T2, i.e.
  211. T = T1 + T2
  212. Kinetic energy is a scalar.
  213. Parameters
  214. ==========
  215. frame : ReferenceFrame
  216. The frame in which the velocity or angular velocity of the body is
  217. defined.
  218. body1, body2, body3... : Particle and/or RigidBody
  219. The body (or bodies) whose kinetic energy is required.
  220. Examples
  221. ========
  222. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  223. >>> from sympy.physics.mechanics import RigidBody, outer, kinetic_energy
  224. >>> N = ReferenceFrame('N')
  225. >>> O = Point('O')
  226. >>> O.set_vel(N, 0 * N.x)
  227. >>> P = O.locatenew('P', 1 * N.x)
  228. >>> P.set_vel(N, 10 * N.x)
  229. >>> Pa = Particle('Pa', P, 1)
  230. >>> Ac = O.locatenew('Ac', 2 * N.y)
  231. >>> Ac.set_vel(N, 5 * N.y)
  232. >>> a = ReferenceFrame('a')
  233. >>> a.set_ang_vel(N, 10 * N.z)
  234. >>> I = outer(N.z, N.z)
  235. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  236. >>> kinetic_energy(N, Pa, A)
  237. 350
  238. """
  239. if not isinstance(frame, ReferenceFrame):
  240. raise TypeError('Please enter a valid ReferenceFrame')
  241. ke_sys = S.Zero
  242. for e in body:
  243. if isinstance(e, (RigidBody, Particle)):
  244. ke_sys += e.kinetic_energy(frame)
  245. else:
  246. raise TypeError('*body must have only Particle or RigidBody')
  247. return ke_sys
  248. def potential_energy(*body):
  249. """Potential energy of a multibody system.
  250. Explanation
  251. ===========
  252. This function returns the potential energy of a system of Particle's and/or
  253. RigidBody's. The potential energy of such a system is equal to the sum of
  254. the potential energy of its constituents. Consider a system, S, comprising
  255. a rigid body, A, and a particle, P. The potential energy of the system, V,
  256. is equal to the vector sum of the potential energy of the particle, V1, and
  257. the potential energy of the rigid body, V2, i.e.
  258. V = V1 + V2
  259. Potential energy is a scalar.
  260. Parameters
  261. ==========
  262. body1, body2, body3... : Particle and/or RigidBody
  263. The body (or bodies) whose potential energy is required.
  264. Examples
  265. ========
  266. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  267. >>> from sympy.physics.mechanics import RigidBody, outer, potential_energy
  268. >>> from sympy import symbols
  269. >>> M, m, g, h = symbols('M m g h')
  270. >>> N = ReferenceFrame('N')
  271. >>> O = Point('O')
  272. >>> O.set_vel(N, 0 * N.x)
  273. >>> P = O.locatenew('P', 1 * N.x)
  274. >>> Pa = Particle('Pa', P, m)
  275. >>> Ac = O.locatenew('Ac', 2 * N.y)
  276. >>> a = ReferenceFrame('a')
  277. >>> I = outer(N.z, N.z)
  278. >>> A = RigidBody('A', Ac, a, M, (I, Ac))
  279. >>> Pa.potential_energy = m * g * h
  280. >>> A.potential_energy = M * g * h
  281. >>> potential_energy(Pa, A)
  282. M*g*h + g*h*m
  283. """
  284. pe_sys = S.Zero
  285. for e in body:
  286. if isinstance(e, (RigidBody, Particle)):
  287. pe_sys += e.potential_energy
  288. else:
  289. raise TypeError('*body must have only Particle or RigidBody')
  290. return pe_sys
  291. def gravity(acceleration, *bodies):
  292. """
  293. Returns a list of gravity forces given the acceleration
  294. due to gravity and any number of particles or rigidbodies.
  295. Example
  296. =======
  297. >>> from sympy.physics.mechanics import ReferenceFrame, Point, Particle, outer, RigidBody
  298. >>> from sympy.physics.mechanics.functions import gravity
  299. >>> from sympy import symbols
  300. >>> N = ReferenceFrame('N')
  301. >>> m, M, g = symbols('m M g')
  302. >>> F1, F2 = symbols('F1 F2')
  303. >>> po = Point('po')
  304. >>> pa = Particle('pa', po, m)
  305. >>> A = ReferenceFrame('A')
  306. >>> P = Point('P')
  307. >>> I = outer(A.x, A.x)
  308. >>> B = RigidBody('B', P, A, M, (I, P))
  309. >>> forceList = [(po, F1), (P, F2)]
  310. >>> forceList.extend(gravity(g*N.y, pa, B))
  311. >>> forceList
  312. [(po, F1), (P, F2), (po, g*m*N.y), (P, M*g*N.y)]
  313. """
  314. gravity_force = []
  315. if not bodies:
  316. raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.")
  317. for e in bodies:
  318. point = getattr(e, 'masscenter', None)
  319. if point is None:
  320. point = e.point
  321. gravity_force.append((point, e.mass*acceleration))
  322. return gravity_force
  323. def center_of_mass(point, *bodies):
  324. """
  325. Returns the position vector from the given point to the center of mass
  326. of the given bodies(particles or rigidbodies).
  327. Example
  328. =======
  329. >>> from sympy import symbols, S
  330. >>> from sympy.physics.vector import Point
  331. >>> from sympy.physics.mechanics import Particle, ReferenceFrame, RigidBody, outer
  332. >>> from sympy.physics.mechanics.functions import center_of_mass
  333. >>> a = ReferenceFrame('a')
  334. >>> m = symbols('m', real=True)
  335. >>> p1 = Particle('p1', Point('p1_pt'), S(1))
  336. >>> p2 = Particle('p2', Point('p2_pt'), S(2))
  337. >>> p3 = Particle('p3', Point('p3_pt'), S(3))
  338. >>> p4 = Particle('p4', Point('p4_pt'), m)
  339. >>> b_f = ReferenceFrame('b_f')
  340. >>> b_cm = Point('b_cm')
  341. >>> mb = symbols('mb')
  342. >>> b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm))
  343. >>> p2.point.set_pos(p1.point, a.x)
  344. >>> p3.point.set_pos(p1.point, a.x + a.y)
  345. >>> p4.point.set_pos(p1.point, a.y)
  346. >>> b.masscenter.set_pos(p1.point, a.y + a.z)
  347. >>> point_o=Point('o')
  348. >>> point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b))
  349. >>> expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z
  350. >>> point_o.pos_from(p1.point)
  351. 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z
  352. """
  353. if not bodies:
  354. raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.")
  355. total_mass = 0
  356. vec = Vector(0)
  357. for i in bodies:
  358. total_mass += i.mass
  359. masscenter = getattr(i, 'masscenter', None)
  360. if masscenter is None:
  361. masscenter = i.point
  362. vec += i.mass*masscenter.pos_from(point)
  363. return vec/total_mass
  364. def Lagrangian(frame, *body):
  365. """Lagrangian of a multibody system.
  366. Explanation
  367. ===========
  368. This function returns the Lagrangian of a system of Particle's and/or
  369. RigidBody's. The Lagrangian of such a system is equal to the difference
  370. between the kinetic energies and potential energies of its constituents. If
  371. T and V are the kinetic and potential energies of a system then it's
  372. Lagrangian, L, is defined as
  373. L = T - V
  374. The Lagrangian is a scalar.
  375. Parameters
  376. ==========
  377. frame : ReferenceFrame
  378. The frame in which the velocity or angular velocity of the body is
  379. defined to determine the kinetic energy.
  380. body1, body2, body3... : Particle and/or RigidBody
  381. The body (or bodies) whose Lagrangian is required.
  382. Examples
  383. ========
  384. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  385. >>> from sympy.physics.mechanics import RigidBody, outer, Lagrangian
  386. >>> from sympy import symbols
  387. >>> M, m, g, h = symbols('M m g h')
  388. >>> N = ReferenceFrame('N')
  389. >>> O = Point('O')
  390. >>> O.set_vel(N, 0 * N.x)
  391. >>> P = O.locatenew('P', 1 * N.x)
  392. >>> P.set_vel(N, 10 * N.x)
  393. >>> Pa = Particle('Pa', P, 1)
  394. >>> Ac = O.locatenew('Ac', 2 * N.y)
  395. >>> Ac.set_vel(N, 5 * N.y)
  396. >>> a = ReferenceFrame('a')
  397. >>> a.set_ang_vel(N, 10 * N.z)
  398. >>> I = outer(N.z, N.z)
  399. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  400. >>> Pa.potential_energy = m * g * h
  401. >>> A.potential_energy = M * g * h
  402. >>> Lagrangian(N, Pa, A)
  403. -M*g*h - g*h*m + 350
  404. """
  405. if not isinstance(frame, ReferenceFrame):
  406. raise TypeError('Please supply a valid ReferenceFrame')
  407. for e in body:
  408. if not isinstance(e, (RigidBody, Particle)):
  409. raise TypeError('*body must have only Particle or RigidBody')
  410. return kinetic_energy(frame, *body) - potential_energy(*body)
  411. def find_dynamicsymbols(expression, exclude=None, reference_frame=None):
  412. """Find all dynamicsymbols in expression.
  413. Explanation
  414. ===========
  415. If the optional ``exclude`` kwarg is used, only dynamicsymbols
  416. not in the iterable ``exclude`` are returned.
  417. If we intend to apply this function on a vector, the optional
  418. ``reference_frame`` is also used to inform about the corresponding frame
  419. with respect to which the dynamic symbols of the given vector is to be
  420. determined.
  421. Parameters
  422. ==========
  423. expression : SymPy expression
  424. exclude : iterable of dynamicsymbols, optional
  425. reference_frame : ReferenceFrame, optional
  426. The frame with respect to which the dynamic symbols of the
  427. given vector is to be determined.
  428. Examples
  429. ========
  430. >>> from sympy.physics.mechanics import dynamicsymbols, find_dynamicsymbols
  431. >>> from sympy.physics.mechanics import ReferenceFrame
  432. >>> x, y = dynamicsymbols('x, y')
  433. >>> expr = x + x.diff()*y
  434. >>> find_dynamicsymbols(expr)
  435. {x(t), y(t), Derivative(x(t), t)}
  436. >>> find_dynamicsymbols(expr, exclude=[x, y])
  437. {Derivative(x(t), t)}
  438. >>> a, b, c = dynamicsymbols('a, b, c')
  439. >>> A = ReferenceFrame('A')
  440. >>> v = a * A.x + b * A.y + c * A.z
  441. >>> find_dynamicsymbols(v, reference_frame=A)
  442. {a(t), b(t), c(t)}
  443. """
  444. t_set = {dynamicsymbols._t}
  445. if exclude:
  446. if iterable(exclude):
  447. exclude_set = set(exclude)
  448. else:
  449. raise TypeError("exclude kwarg must be iterable")
  450. else:
  451. exclude_set = set()
  452. if isinstance(expression, Vector):
  453. if reference_frame is None:
  454. raise ValueError("You must provide reference_frame when passing a "
  455. "vector expression, got %s." % reference_frame)
  456. else:
  457. expression = expression.to_matrix(reference_frame)
  458. return {i for i in expression.atoms(AppliedUndef, Derivative) if
  459. i.free_symbols == t_set} - exclude_set
  460. def msubs(expr, *sub_dicts, smart=False, **kwargs):
  461. """A custom subs for use on expressions derived in physics.mechanics.
  462. Traverses the expression tree once, performing the subs found in sub_dicts.
  463. Terms inside ``Derivative`` expressions are ignored:
  464. Examples
  465. ========
  466. >>> from sympy.physics.mechanics import dynamicsymbols, msubs
  467. >>> x = dynamicsymbols('x')
  468. >>> msubs(x.diff() + x, {x: 1})
  469. Derivative(x(t), t) + 1
  470. Note that sub_dicts can be a single dictionary, or several dictionaries:
  471. >>> x, y, z = dynamicsymbols('x, y, z')
  472. >>> sub1 = {x: 1, y: 2}
  473. >>> sub2 = {z: 3, x.diff(): 4}
  474. >>> msubs(x.diff() + x + y + z, sub1, sub2)
  475. 10
  476. If smart=True (default False), also checks for conditions that may result
  477. in ``nan``, but if simplified would yield a valid expression. For example:
  478. >>> from sympy import sin, tan
  479. >>> (sin(x)/tan(x)).subs(x, 0)
  480. nan
  481. >>> msubs(sin(x)/tan(x), {x: 0}, smart=True)
  482. 1
  483. It does this by first replacing all ``tan`` with ``sin/cos``. Then each
  484. node is traversed. If the node is a fraction, subs is first evaluated on
  485. the denominator. If this results in 0, simplification of the entire
  486. fraction is attempted. Using this selective simplification, only
  487. subexpressions that result in 1/0 are targeted, resulting in faster
  488. performance.
  489. """
  490. sub_dict = dict_merge(*sub_dicts)
  491. if smart:
  492. func = _smart_subs
  493. elif hasattr(expr, 'msubs'):
  494. return expr.msubs(sub_dict)
  495. else:
  496. func = lambda expr, sub_dict: _crawl(expr, _sub_func, sub_dict)
  497. if isinstance(expr, (Matrix, Vector, Dyadic)):
  498. return expr.applyfunc(lambda x: func(x, sub_dict))
  499. else:
  500. return func(expr, sub_dict)
  501. def _crawl(expr, func, *args, **kwargs):
  502. """Crawl the expression tree, and apply func to every node."""
  503. val = func(expr, *args, **kwargs)
  504. if val is not None:
  505. return val
  506. new_args = (_crawl(arg, func, *args, **kwargs) for arg in expr.args)
  507. return expr.func(*new_args)
  508. def _sub_func(expr, sub_dict):
  509. """Perform direct matching substitution, ignoring derivatives."""
  510. if expr in sub_dict:
  511. return sub_dict[expr]
  512. elif not expr.args or expr.is_Derivative:
  513. return expr
  514. def _tan_repl_func(expr):
  515. """Replace tan with sin/cos."""
  516. if isinstance(expr, tan):
  517. return sin(*expr.args) / cos(*expr.args)
  518. elif not expr.args or expr.is_Derivative:
  519. return expr
  520. def _smart_subs(expr, sub_dict):
  521. """Performs subs, checking for conditions that may result in `nan` or
  522. `oo`, and attempts to simplify them out.
  523. The expression tree is traversed twice, and the following steps are
  524. performed on each expression node:
  525. - First traverse:
  526. Replace all `tan` with `sin/cos`.
  527. - Second traverse:
  528. If node is a fraction, check if the denominator evaluates to 0.
  529. If so, attempt to simplify it out. Then if node is in sub_dict,
  530. sub in the corresponding value.
  531. """
  532. expr = _crawl(expr, _tan_repl_func)
  533. def _recurser(expr, sub_dict):
  534. # Decompose the expression into num, den
  535. num, den = _fraction_decomp(expr)
  536. if den != 1:
  537. # If there is a non trivial denominator, we need to handle it
  538. denom_subbed = _recurser(den, sub_dict)
  539. if denom_subbed.evalf() == 0:
  540. # If denom is 0 after this, attempt to simplify the bad expr
  541. expr = simplify(expr)
  542. else:
  543. # Expression won't result in nan, find numerator
  544. num_subbed = _recurser(num, sub_dict)
  545. return num_subbed / denom_subbed
  546. # We have to crawl the tree manually, because `expr` may have been
  547. # modified in the simplify step. First, perform subs as normal:
  548. val = _sub_func(expr, sub_dict)
  549. if val is not None:
  550. return val
  551. new_args = (_recurser(arg, sub_dict) for arg in expr.args)
  552. return expr.func(*new_args)
  553. return _recurser(expr, sub_dict)
  554. def _fraction_decomp(expr):
  555. """Return num, den such that expr = num/den."""
  556. if not isinstance(expr, Mul):
  557. return expr, 1
  558. num = []
  559. den = []
  560. for a in expr.args:
  561. if a.is_Pow and a.args[1] < 0:
  562. den.append(1 / a)
  563. else:
  564. num.append(a)
  565. if not den:
  566. return expr, 1
  567. num = Mul(*num)
  568. den = Mul(*den)
  569. return num, den
  570. def _f_list_parser(fl, ref_frame):
  571. """Parses the provided forcelist composed of items
  572. of the form (obj, force).
  573. Returns a tuple containing:
  574. vel_list: The velocity (ang_vel for Frames, vel for Points) in
  575. the provided reference frame.
  576. f_list: The forces.
  577. Used internally in the KanesMethod and LagrangesMethod classes.
  578. """
  579. def flist_iter():
  580. for pair in fl:
  581. obj, force = pair
  582. if isinstance(obj, ReferenceFrame):
  583. yield obj.ang_vel_in(ref_frame), force
  584. elif isinstance(obj, Point):
  585. yield obj.vel(ref_frame), force
  586. else:
  587. raise TypeError('First entry in each forcelist pair must '
  588. 'be a point or frame.')
  589. if not fl:
  590. vel_list, f_list = (), ()
  591. else:
  592. unzip = lambda l: list(zip(*l)) if l[0] else [(), ()]
  593. vel_list, f_list = unzip(list(flist_iter()))
  594. return vel_list, f_list
  595. def _validate_coordinates(coordinates=None, speeds=None, check_duplicates=True,
  596. is_dynamicsymbols=True):
  597. t_set = {dynamicsymbols._t}
  598. # Convert input to iterables
  599. if coordinates is None:
  600. coordinates = []
  601. elif not iterable(coordinates):
  602. coordinates = [coordinates]
  603. if speeds is None:
  604. speeds = []
  605. elif not iterable(speeds):
  606. speeds = [speeds]
  607. if check_duplicates: # Check for duplicates
  608. seen = set()
  609. coord_duplicates = {x for x in coordinates if x in seen or seen.add(x)}
  610. seen = set()
  611. speed_duplicates = {x for x in speeds if x in seen or seen.add(x)}
  612. overlap = set(coordinates).intersection(speeds)
  613. if coord_duplicates:
  614. raise ValueError(f'The generalized coordinates {coord_duplicates} '
  615. f'are duplicated, all generalized coordinates '
  616. f'should be unique.')
  617. if speed_duplicates:
  618. raise ValueError(f'The generalized speeds {speed_duplicates} are '
  619. f'duplicated, all generalized speeds should be '
  620. f'unique.')
  621. if overlap:
  622. raise ValueError(f'{overlap} are defined as both generalized '
  623. f'coordinates and generalized speeds.')
  624. if is_dynamicsymbols: # Check whether all coordinates are dynamicsymbols
  625. for coordinate in coordinates:
  626. if not (isinstance(coordinate, (AppliedUndef, Derivative)) and
  627. coordinate.free_symbols == t_set):
  628. raise ValueError(f'Generalized coordinate "{coordinate}" is not'
  629. f' a dynamicsymbol.')
  630. for speed in speeds:
  631. if not (isinstance(speed, (AppliedUndef, Derivative)) and
  632. speed.free_symbols == t_set):
  633. raise ValueError(f'Generalized speed "{speed}" is not a '
  634. f'dynamicsymbol.')