  1. """SQL io tests
  2. The SQL tests are broken down in different classes:
  3. - `PandasSQLTest`: base class with common methods for all test classes
  4. - Tests for the public API (only tests with sqlite3)
  5. - `_TestSQLApi` base class
  6. - `TestSQLApi`: test the public API with sqlalchemy engine
  7. - `TestSQLiteFallbackApi`: test the public API with a sqlite DBAPI
  8. connection
  9. - Tests for the different SQL flavors (flavor specific type conversions)
  10. - Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with
  11. common methods. The different tested flavors (sqlite3, MySQL,
  12. PostgreSQL) derive from the base class
  13. - Tests for the fallback mode (`TestSQLiteFallback`)
  14. """
from __future__ import annotations

import contextlib
from contextlib import closing
import csv
from datetime import (
    date,
    datetime,
    time,
    timedelta,
)
from io import StringIO
from pathlib import Path
import sqlite3

import numpy as np
import pytest

from pandas._libs import lib
import pandas.util._test_decorators as td
from pandas.core.dtypes.common import (
    is_datetime64_dtype,
    is_datetime64tz_dtype,
)

import pandas as pd
from pandas import (
    DataFrame,
    Index,
    MultiIndex,
    Series,
    Timestamp,
    concat,
    date_range,
    isna,
    to_datetime,
    to_timedelta,
)
import pandas._testing as tm
from pandas.core.arrays import (
    ArrowStringArray,
    StringArray,
)
from pandas.util.version import Version

from pandas.io import sql
from pandas.io.sql import (
    SQLAlchemyEngine,
    SQLDatabase,
    SQLiteDatabase,
    get_engine,
    pandasSQL_builder,
    read_sql_query,
    read_sql_table,
)

try:
    import sqlalchemy

    SQLALCHEMY_INSTALLED = True
except ImportError:
    SQLALCHEMY_INSTALLED = False


SQL_STRINGS = {
    "read_parameters": {
        "sqlite": "SELECT * FROM iris WHERE Name=? AND SepalLength=?",
        "mysql": "SELECT * FROM iris WHERE `Name`=%s AND `SepalLength`=%s",
        "postgresql": 'SELECT * FROM iris WHERE "Name"=%s AND "SepalLength"=%s',
    },
    "read_named_parameters": {
        "sqlite": """
                SELECT * FROM iris WHERE Name=:name AND SepalLength=:length
                """,
        "mysql": """
                SELECT * FROM iris WHERE
                `Name`=%(name)s AND `SepalLength`=%(length)s
                """,
        "postgresql": """
                SELECT * FROM iris WHERE
                "Name"=%(name)s AND "SepalLength"=%(length)s
                """,
    },
    "read_no_parameters_with_percent": {
        "sqlite": "SELECT * FROM iris WHERE Name LIKE '%'",
        "mysql": "SELECT * FROM iris WHERE `Name` LIKE '%'",
        "postgresql": "SELECT * FROM iris WHERE \"Name\" LIKE '%'",
    },
}
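

# Illustrative sketch (added commentary, not part of the original suite): each
# sub-dict above is keyed by flavor because every DBAPI driver expects a
# different paramstyle -- sqlite3 uses qmark ("?") placeholders, while
# pymysql/psycopg2 use format/pyformat placeholders ("%s" / "%(name)s").
# A minimal, hypothetical use of the sqlite variant, assuming an open sqlite3
# connection with the iris table already loaded:
def _example_parametrized_query(conn: sqlite3.Connection) -> DataFrame:
    query = SQL_STRINGS["read_parameters"]["sqlite"]
    # pandas forwards ``params`` to the driver, which binds the "?" slots
    return read_sql_query(query, conn, params=("Iris-setosa", 5.1))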


def iris_table_metadata(dialect: str):
    from sqlalchemy import (
        REAL,
        Column,
        Float,
        MetaData,
        String,
        Table,
    )

    dtype = Float if dialect == "postgresql" else REAL
    metadata = MetaData()
    iris = Table(
        "iris",
        metadata,
        Column("SepalLength", dtype),
        Column("SepalWidth", dtype),
        Column("PetalLength", dtype),
        Column("PetalWidth", dtype),
        Column("Name", String(200)),
    )
    return iris


def create_and_load_iris_sqlite3(conn: sqlite3.Connection, iris_file: Path):
    cur = conn.cursor()
    stmt = """CREATE TABLE iris (
            "SepalLength" REAL,
            "SepalWidth" REAL,
            "PetalLength" REAL,
            "PetalWidth" REAL,
            "Name" TEXT
        )"""
    cur.execute(stmt)

    with iris_file.open(newline=None) as csvfile:
        reader = csv.reader(csvfile)
        next(reader)
        stmt = "INSERT INTO iris VALUES(?, ?, ?, ?, ?)"
        cur.executemany(stmt, reader)


def create_and_load_iris(conn, iris_file: Path, dialect: str):
    from sqlalchemy import insert
    from sqlalchemy.engine import Engine

    iris = iris_table_metadata(dialect)

    with iris_file.open(newline=None) as csvfile:
        reader = csv.reader(csvfile)
        header = next(reader)
        params = [dict(zip(header, row)) for row in reader]
        stmt = insert(iris).values(params)
        if isinstance(conn, Engine):
            with conn.connect() as conn:
                with conn.begin():
                    iris.drop(conn, checkfirst=True)
                    iris.create(bind=conn)
                    conn.execute(stmt)
        else:
            with conn.begin():
                iris.drop(conn, checkfirst=True)
                iris.create(bind=conn)
                conn.execute(stmt)


def create_and_load_iris_view(conn):
    stmt = "CREATE VIEW iris_view AS SELECT * FROM iris"
    if isinstance(conn, sqlite3.Connection):
        cur = conn.cursor()
        cur.execute(stmt)
    else:
        from sqlalchemy import text
        from sqlalchemy.engine import Engine

        stmt = text(stmt)
        if isinstance(conn, Engine):
            with conn.connect() as conn:
                with conn.begin():
                    conn.execute(stmt)
        else:
            with conn.begin():
                conn.execute(stmt)


def types_table_metadata(dialect: str):
    from sqlalchemy import (
        TEXT,
        Boolean,
        Column,
        DateTime,
        Float,
        Integer,
        MetaData,
        Table,
    )

    date_type = TEXT if dialect == "sqlite" else DateTime
    bool_type = Integer if dialect == "sqlite" else Boolean
    metadata = MetaData()
    types = Table(
        "types",
        metadata,
        Column("TextCol", TEXT),
        Column("DateCol", date_type),
        Column("IntDateCol", Integer),
        Column("IntDateOnlyCol", Integer),
        Column("FloatCol", Float),
        Column("IntCol", Integer),
        Column("BoolCol", bool_type),
        Column("IntColWithNull", Integer),
        Column("BoolColWithNull", bool_type),
    )
    if dialect == "postgresql":
        types.append_column(Column("DateColWithTz", DateTime(timezone=True)))
    return types


def create_and_load_types_sqlite3(conn: sqlite3.Connection, types_data: list[dict]):
    cur = conn.cursor()
    stmt = """CREATE TABLE types (
                    "TextCol" TEXT,
                    "DateCol" TEXT,
                    "IntDateCol" INTEGER,
                    "IntDateOnlyCol" INTEGER,
                    "FloatCol" REAL,
                    "IntCol" INTEGER,
                    "BoolCol" INTEGER,
                    "IntColWithNull" INTEGER,
                    "BoolColWithNull" INTEGER
                )"""
    cur.execute(stmt)

    stmt = """
            INSERT INTO types
            VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)
            """
    cur.executemany(stmt, types_data)


def create_and_load_types(conn, types_data: list[dict], dialect: str):
    from sqlalchemy import insert
    from sqlalchemy.engine import Engine

    types = types_table_metadata(dialect)

    stmt = insert(types).values(types_data)
    if isinstance(conn, Engine):
        with conn.connect() as conn:
            with conn.begin():
                types.drop(conn, checkfirst=True)
                types.create(bind=conn)
                conn.execute(stmt)
    else:
        with conn.begin():
            types.drop(conn, checkfirst=True)
            types.create(bind=conn)
            conn.execute(stmt)


def check_iris_frame(frame: DataFrame):
    pytype = frame.dtypes[0].type
    row = frame.iloc[0]
    assert issubclass(pytype, np.floating)
    tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
    assert frame.shape in ((150, 5), (8, 5))


def count_rows(conn, table_name: str):
    # Accepts a raw sqlite3 connection, a SQLAlchemy URI string, a SQLAlchemy
    # Engine, or an already-open SQLAlchemy connection.
    stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
    if isinstance(conn, sqlite3.Connection):
        cur = conn.cursor()
        return cur.execute(stmt).fetchone()[0]
    else:
        from sqlalchemy import create_engine
        from sqlalchemy.engine import Engine

        if isinstance(conn, str):
            engine = create_engine(conn)
            try:
                with engine.connect() as conn:
                    return conn.exec_driver_sql(stmt).scalar_one()
            finally:
                engine.dispose()
        elif isinstance(conn, Engine):
            with conn.connect() as conn:
                return conn.exec_driver_sql(stmt).scalar_one()
        else:
            return conn.exec_driver_sql(stmt).scalar_one()


@pytest.fixture
def iris_path(datapath):
    iris_path = datapath("io", "data", "csv", "iris.csv")
    return Path(iris_path)


@pytest.fixture
def types_data():
    return [
        {
            "TextCol": "first",
            "DateCol": "2000-01-03 00:00:00",
            "IntDateCol": 535852800,
            "IntDateOnlyCol": 20101010,
            "FloatCol": 10.10,
            "IntCol": 1,
            "BoolCol": False,
            "IntColWithNull": 1,
            "BoolColWithNull": False,
            "DateColWithTz": "2000-01-01 00:00:00-08:00",
        },
        {
            "TextCol": "first",
            "DateCol": "2000-01-04 00:00:00",
            "IntDateCol": 1356998400,
            "IntDateOnlyCol": 20101212,
            "FloatCol": 10.10,
            "IntCol": 1,
            "BoolCol": False,
            "IntColWithNull": None,
            "BoolColWithNull": None,
            "DateColWithTz": "2000-06-01 00:00:00-07:00",
        },
    ]


@pytest.fixture
def types_data_frame(types_data):
    dtypes = {
        "TextCol": "str",
        "DateCol": "str",
        "IntDateCol": "int64",
        "IntDateOnlyCol": "int64",
        "FloatCol": "float",
        "IntCol": "int64",
        "BoolCol": "int64",
        "IntColWithNull": "float",
        "BoolColWithNull": "float",
    }
    df = DataFrame(types_data)
    return df[dtypes.keys()].astype(dtypes)


@pytest.fixture
def test_frame1():
    columns = ["index", "A", "B", "C", "D"]
    data = [
        (
            "2000-01-03 00:00:00",
            0.980268513777,
            3.68573087906,
            -0.364216805298,
            -1.15973806169,
        ),
        (
            "2000-01-04 00:00:00",
            1.04791624281,
            -0.0412318367011,
            -0.16181208307,
            0.212549316967,
        ),
        (
            "2000-01-05 00:00:00",
            0.498580885705,
            0.731167677815,
            -0.537677223318,
            1.34627041952,
        ),
        (
            "2000-01-06 00:00:00",
            1.12020151869,
            1.56762092543,
            0.00364077397681,
            0.67525259227,
        ),
    ]
    return DataFrame(data, columns=columns)


@pytest.fixture
def test_frame3():
    columns = ["index", "A", "B"]
    data = [
        ("2000-01-03 00:00:00", 2**31 - 1, -1.987670),
        ("2000-01-04 00:00:00", -29, -0.0412318367011),
        ("2000-01-05 00:00:00", 20000, 0.731167677815),
        ("2000-01-06 00:00:00", -290867, 1.56762092543),
    ]
    return DataFrame(data, columns=columns)


@pytest.fixture
def mysql_pymysql_engine(iris_path, types_data):
    sqlalchemy = pytest.importorskip("sqlalchemy")
    pymysql = pytest.importorskip("pymysql")
    engine = sqlalchemy.create_engine(
        "mysql+pymysql://root@localhost:3306/pandas",
        connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
        poolclass=sqlalchemy.pool.NullPool,
    )
    insp = sqlalchemy.inspect(engine)
    if not insp.has_table("iris"):
        create_and_load_iris(engine, iris_path, "mysql")
    if not insp.has_table("types"):
        for entry in types_data:
            entry.pop("DateColWithTz")
        create_and_load_types(engine, types_data, "mysql")
    yield engine
    with engine.connect() as conn:
        with conn.begin():
            stmt = sqlalchemy.text("DROP TABLE IF EXISTS test_frame;")
            conn.execute(stmt)
    engine.dispose()


@pytest.fixture
def mysql_pymysql_conn(mysql_pymysql_engine):
    with mysql_pymysql_engine.connect() as conn:
        yield conn


@pytest.fixture
def postgresql_psycopg2_engine(iris_path, types_data):
    sqlalchemy = pytest.importorskip("sqlalchemy")
    pytest.importorskip("psycopg2")
    engine = sqlalchemy.create_engine(
        "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas",
        poolclass=sqlalchemy.pool.NullPool,
    )
    insp = sqlalchemy.inspect(engine)
    if not insp.has_table("iris"):
        create_and_load_iris(engine, iris_path, "postgresql")
    if not insp.has_table("types"):
        create_and_load_types(engine, types_data, "postgresql")
    yield engine
    with engine.connect() as conn:
        with conn.begin():
            stmt = sqlalchemy.text("DROP TABLE IF EXISTS test_frame;")
            conn.execute(stmt)
    engine.dispose()


@pytest.fixture
def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
    with postgresql_psycopg2_engine.connect() as conn:
        yield conn


@pytest.fixture
def sqlite_str():
    pytest.importorskip("sqlalchemy")
    with tm.ensure_clean() as name:
        yield "sqlite:///" + name


@pytest.fixture
def sqlite_engine(sqlite_str):
    sqlalchemy = pytest.importorskip("sqlalchemy")
    engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool)
    yield engine
    engine.dispose()


@pytest.fixture
def sqlite_conn(sqlite_engine):
    with sqlite_engine.connect() as conn:
        yield conn


@pytest.fixture
def sqlite_iris_str(sqlite_str, iris_path):
    sqlalchemy = pytest.importorskip("sqlalchemy")
    engine = sqlalchemy.create_engine(sqlite_str)
    create_and_load_iris(engine, iris_path, "sqlite")
    engine.dispose()
    return sqlite_str


@pytest.fixture
def sqlite_iris_engine(sqlite_engine, iris_path):
    create_and_load_iris(sqlite_engine, iris_path, "sqlite")
    return sqlite_engine


@pytest.fixture
def sqlite_iris_conn(sqlite_iris_engine):
    with sqlite_iris_engine.connect() as conn:
        yield conn


@pytest.fixture
def sqlite_buildin():
    with contextlib.closing(sqlite3.connect(":memory:")) as closing_conn:
        with closing_conn as conn:
            yield conn


@pytest.fixture
def sqlite_buildin_iris(sqlite_buildin, iris_path):
    create_and_load_iris_sqlite3(sqlite_buildin, iris_path)
    return sqlite_buildin


mysql_connectable = [
    "mysql_pymysql_engine",
    "mysql_pymysql_conn",
]

postgresql_connectable = [
    "postgresql_psycopg2_engine",
    "postgresql_psycopg2_conn",
]

sqlite_connectable = [
    "sqlite_engine",
    "sqlite_conn",
    "sqlite_str",
]

sqlite_iris_connectable = [
    "sqlite_iris_engine",
    "sqlite_iris_conn",
    "sqlite_iris_str",
]

sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable

sqlalchemy_connectable_iris = (
    mysql_connectable + postgresql_connectable + sqlite_iris_connectable
)

all_connectable = sqlalchemy_connectable + ["sqlite_buildin"]

all_connectable_iris = sqlalchemy_connectable_iris + ["sqlite_buildin_iris"]
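

# Note on the parametrization pattern used throughout the tests below (added
# commentary, not from the original file): ``conn`` is parametrized with
# fixture *names* (strings), and each test resolves its backend lazily, e.g.:
#
#     @pytest.mark.parametrize("conn", all_connectable)
#     def test_something(conn, request):
#         conn = request.getfixturevalue(conn)  # "sqlite_engine" -> Engine
#
# This lets one test body run against engines, open connections, URI strings,
# and the bare sqlite3 fallback alike, while only instantiating the backend a
# given test actually uses.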


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql(conn, test_frame1, request):
    # GH 51086 if conn is sqlite_engine
    conn = request.getfixturevalue(conn)
    test_frame1.to_sql("test", conn, if_exists="append", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql_arrow_dtypes(conn, request):
    # GH 52046
    pytest.importorskip("pyarrow")
    df = DataFrame(
        {
            "int": pd.array([1], dtype="int8[pyarrow]"),
            "datetime": pd.array(
                [datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]"
            ),
            "timedelta": pd.array([timedelta(1)], dtype="duration[ns][pyarrow]"),
            "string": pd.array(["a"], dtype="string[pyarrow]"),
        }
    )
    conn = request.getfixturevalue(conn)
    with tm.assert_produces_warning(UserWarning, match="the 'timedelta'"):
        df.to_sql("test_arrow", conn, if_exists="replace", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql_arrow_dtypes_missing(conn, request, nulls_fixture):
    # GH 52046
    pytest.importorskip("pyarrow")
    df = DataFrame(
        {
            "datetime": pd.array(
                [datetime(2023, 1, 1), nulls_fixture], dtype="timestamp[ns][pyarrow]"
            ),
        }
    )
    conn = request.getfixturevalue(conn)
    df.to_sql("test_arrow", conn, if_exists="replace", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize("method", [None, "multi"])
def test_to_sql(conn, method, test_frame1, request):
    conn = request.getfixturevalue(conn)
    with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
        pandasSQL.to_sql(test_frame1, "test_frame", method=method)
        assert pandasSQL.has_table("test_frame")
    assert count_rows(conn, "test_frame") == len(test_frame1)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)])
def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
    conn = request.getfixturevalue(conn)
    with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
        pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
        pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
        assert pandasSQL.has_table("test_frame")
    assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_to_sql_exist_fail(conn, test_frame1, request):
    conn = request.getfixturevalue(conn)
    with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
        pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
        assert pandasSQL.has_table("test_frame")

        msg = "Table 'test_frame' already exists"
        with pytest.raises(ValueError, match=msg):
            pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query(conn, request):
    conn = request.getfixturevalue(conn)
    iris_frame = read_sql_query("SELECT * FROM iris", conn)
    check_iris_frame(iris_frame)
    iris_frame = pd.read_sql("SELECT * FROM iris", conn)
    check_iris_frame(iris_frame)
    iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn)
    assert iris_frame.shape == (0, 5)
    assert "SepalWidth" in iris_frame.columns


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query_chunksize(conn, request):
    conn = request.getfixturevalue(conn)
    iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7))
    assert iris_frame.shape == (0, 5)
    assert "SepalWidth" in iris_frame.columns


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_query_expression_with_parameter(conn, request):
    conn = request.getfixturevalue(conn)
    from sqlalchemy import (
        MetaData,
        Table,
        create_engine,
        select,
    )

    metadata = MetaData()
    autoload_con = create_engine(conn) if isinstance(conn, str) else conn
    iris = Table("iris", metadata, autoload_with=autoload_con)
    iris_frame = read_sql_query(
        select(iris), conn, params={"name": "Iris-setosa", "length": 5.1}
    )
    check_iris_frame(iris_frame)
    if isinstance(conn, str):
        autoload_con.dispose()


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query_string_with_parameter(conn, request):
    for db, query in SQL_STRINGS["read_parameters"].items():
        if db in conn:
            break
    else:
        raise KeyError(f"No part of {conn} found in SQL_STRINGS['read_parameters']")
    conn = request.getfixturevalue(conn)
    iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1))
    check_iris_frame(iris_frame)


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_table(conn, request):
    # GH 51015 if conn = sqlite_iris_str
    conn = request.getfixturevalue(conn)
    iris_frame = read_sql_table("iris", conn)
    check_iris_frame(iris_frame)
    iris_frame = pd.read_sql("iris", conn)
    check_iris_frame(iris_frame)


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_table_chunksize(conn, request):
    conn = request.getfixturevalue(conn)
    iris_frame = concat(read_sql_table("iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    iris_frame = concat(pd.read_sql("iris", conn, chunksize=7))
    check_iris_frame(iris_frame)


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable)
def test_to_sql_callable(conn, test_frame1, request):
    conn = request.getfixturevalue(conn)

    check = []  # used to double check function below is really being used

    def sample(pd_table, conn, keys, data_iter):
        check.append(1)
        data = [dict(zip(keys, row)) for row in data_iter]
        conn.execute(pd_table.table.insert(), data)

    with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
        pandasSQL.to_sql(test_frame1, "test_frame", method=sample)
        assert pandasSQL.has_table("test_frame")
    assert check == [1]
    assert count_rows(conn, "test_frame") == len(test_frame1)


@pytest.mark.db
@pytest.mark.parametrize("conn", mysql_connectable)
def test_default_type_conversion(conn, request):
    conn = request.getfixturevalue(conn)
    df = sql.read_sql_table("types", conn)

    assert issubclass(df.FloatCol.dtype.type, np.floating)
    assert issubclass(df.IntCol.dtype.type, np.integer)

    # MySQL has no real BOOL type (it's an alias for TINYINT)
    assert issubclass(df.BoolCol.dtype.type, np.integer)

    # Int column with NA values stays as float
    assert issubclass(df.IntColWithNull.dtype.type, np.floating)

    # Bool column with NA = int column with NA values => becomes float
    assert issubclass(df.BoolColWithNull.dtype.type, np.floating)


@pytest.mark.db
@pytest.mark.parametrize("conn", mysql_connectable)
def test_read_procedure(conn, request):
    conn = request.getfixturevalue(conn)

    # GH 7324
    # Although it is more an api test, it is added to the
    # mysql tests as sqlite does not have stored procedures
    from sqlalchemy import text
    from sqlalchemy.engine import Engine

    df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
    df.to_sql("test_frame", conn, index=False)

    proc = """DROP PROCEDURE IF EXISTS get_testdb;
    CREATE PROCEDURE get_testdb ()
    BEGIN
        SELECT * FROM test_frame;
    END"""
    proc = text(proc)
    if isinstance(conn, Engine):
        with conn.connect() as engine_conn:
            with engine_conn.begin():
                engine_conn.execute(proc)
    else:
        with conn.begin():
            conn.execute(proc)

    res1 = sql.read_sql_query("CALL get_testdb();", conn)
    tm.assert_frame_equal(df, res1)

    # test delegation to read_sql_query
    res2 = sql.read_sql("CALL get_testdb();", conn)
    tm.assert_frame_equal(df, res2)


@pytest.mark.db
@pytest.mark.parametrize("conn", postgresql_connectable)
@pytest.mark.parametrize("expected_count", [2, "Success!"])
def test_copy_from_callable_insertion_method(conn, expected_count, request):
    # GH 8953
    # Example in io.rst found under _io.sql.method
    # not available in sqlite, mysql
    def psql_insert_copy(table, conn, keys, data_iter):
        # gets a DBAPI connection that can provide a cursor
        dbapi_conn = conn.connection
        with dbapi_conn.cursor() as cur:
            s_buf = StringIO()
            writer = csv.writer(s_buf)
            writer.writerows(data_iter)
            s_buf.seek(0)

            columns = ", ".join([f'"{k}"' for k in keys])
            if table.schema:
                table_name = f"{table.schema}.{table.name}"
            else:
                table_name = table.name

            sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV"
            cur.copy_expert(sql=sql_query, file=s_buf)
        return expected_count

    conn = request.getfixturevalue(conn)
    expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})
    result_count = expected.to_sql(
        "test_frame", conn, index=False, method=psql_insert_copy
    )
    # GH 46891
    if not isinstance(expected_count, int):
        assert result_count is None
    else:
        assert result_count == expected_count
    result = sql.read_sql_table("test_frame", conn)
    tm.assert_frame_equal(result, expected)
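

# Note on the pattern above (added commentary, not from the original file):
# psql_insert_copy matches the callable signature pandas documents for the
# ``method`` argument of to_sql -- (table, conn, keys, data_iter) -- and
# streams the whole batch through PostgreSQL's COPY ... FROM STDIN, which is
# typically much faster than row-wise INSERTs for bulk loads. cur.copy_expert
# is psycopg2-specific, which is why this test only runs against
# postgresql_connectable.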


def test_execute_typeerror(sqlite_iris_engine):
    with pytest.raises(TypeError, match="pandas.io.sql.execute requires a connection"):
        with tm.assert_produces_warning(
            FutureWarning,
            match="`pandas.io.sql.execute` is deprecated and "
            "will be removed in the future version.",
        ):
            sql.execute("select * from iris", sqlite_iris_engine)


def test_execute_deprecated(sqlite_buildin_iris):
    # GH50185
    with tm.assert_produces_warning(
        FutureWarning,
        match="`pandas.io.sql.execute` is deprecated and "
        "will be removed in the future version.",
    ):
        sql.execute("select * from iris", sqlite_buildin_iris)


class MixInBase:
    def teardown_method(self):
        # if setup fails, there may not be a connection to close.
        if hasattr(self, "conn"):
            self.conn.close()
        # use a fresh connection to ensure we can drop all tables.
        try:
            conn = self.connect()
        except (sqlalchemy.exc.OperationalError, sqlite3.OperationalError):
            pass
        else:
            with conn:
                for view in self._get_all_views(conn):
                    self.drop_view(view, conn)
                for tbl in self._get_all_tables(conn):
                    self.drop_table(tbl, conn)


class SQLiteMixIn(MixInBase):
    def connect(self):
        return sqlite3.connect(":memory:")

    def drop_table(self, table_name, conn):
        conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}")
        conn.commit()

    def _get_all_tables(self, conn):
        c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        return [table[0] for table in c.fetchall()]

    def drop_view(self, view_name, conn):
        conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}")
        conn.commit()

    def _get_all_views(self, conn):
        c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'")
        return [view[0] for view in c.fetchall()]


class SQLAlchemyMixIn(MixInBase):
    @classmethod
    def teardown_class(cls):
        cls.engine.dispose()

    def connect(self):
        return self.engine.connect()

    def drop_table(self, table_name, conn):
        if conn.in_transaction():
            conn.get_transaction().rollback()
        with conn.begin():
            sql.SQLDatabase(conn).drop_table(table_name)

    def _get_all_tables(self, conn):
        from sqlalchemy import inspect

        return inspect(conn).get_table_names()

    def drop_view(self, view_name, conn):
        quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier(
            view_name
        )
        if conn.in_transaction():
            conn.get_transaction().rollback()
        with conn.begin():
            conn.exec_driver_sql(f"DROP VIEW IF EXISTS {quoted_view}")

    def _get_all_views(self, conn):
        from sqlalchemy import inspect

        return inspect(conn).get_view_names()


class PandasSQLTest:
    """
    Base class with common private methods for SQLAlchemy and fallback cases.
    """

    def load_iris_data(self, iris_path):
        self.drop_table("iris", self.conn)
        if isinstance(self.conn, sqlite3.Connection):
            create_and_load_iris_sqlite3(self.conn, iris_path)
        else:
            create_and_load_iris(self.conn, iris_path, self.flavor)

    def load_types_data(self, types_data):
        if self.flavor != "postgresql":
            for entry in types_data:
                entry.pop("DateColWithTz")
        if isinstance(self.conn, sqlite3.Connection):
            types_data = [tuple(entry.values()) for entry in types_data]
            create_and_load_types_sqlite3(self.conn, types_data)
        else:
            create_and_load_types(self.conn, types_data, self.flavor)

    def _read_sql_iris_parameter(self):
        query = SQL_STRINGS["read_parameters"][self.flavor]
        params = ("Iris-setosa", 5.1)
        iris_frame = self.pandasSQL.read_query(query, params=params)
        check_iris_frame(iris_frame)

    def _read_sql_iris_named_parameter(self):
        query = SQL_STRINGS["read_named_parameters"][self.flavor]
        params = {"name": "Iris-setosa", "length": 5.1}
        iris_frame = self.pandasSQL.read_query(query, params=params)
        check_iris_frame(iris_frame)

    def _read_sql_iris_no_parameter_with_percent(self):
        query = SQL_STRINGS["read_no_parameters_with_percent"][self.flavor]
        iris_frame = self.pandasSQL.read_query(query, params=None)
        check_iris_frame(iris_frame)

    def _to_sql_empty(self, test_frame1):
        self.drop_table("test_frame1", self.conn)
        assert self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1") == 0

    def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs):
        """`to_sql` with the `engine` param"""
        # mostly copied from this class's `_to_sql()` method
        self.drop_table("test_frame1", self.conn)

        assert (
            self.pandasSQL.to_sql(
                test_frame1, "test_frame1", engine=engine, **engine_kwargs
            )
            == 4
        )
        assert self.pandasSQL.has_table("test_frame1")

        num_entries = len(test_frame1)
        num_rows = count_rows(self.conn, "test_frame1")
        assert num_rows == num_entries

        # Nuke table
        self.drop_table("test_frame1", self.conn)

    def _roundtrip(self, test_frame1):
        self.drop_table("test_frame_roundtrip", self.conn)
        assert self.pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") == 4
        result = self.pandasSQL.read_query("SELECT * FROM test_frame_roundtrip")

        result.set_index("level_0", inplace=True)
        # result.index.astype(int)

        result.index.name = None

        tm.assert_frame_equal(result, test_frame1)

    def _execute_sql(self):
        # drop_sql = "DROP TABLE IF EXISTS test"  # should already be done
        iris_results = self.pandasSQL.execute("SELECT * FROM iris")
        row = iris_results.fetchone()
        tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])

    def _to_sql_save_index(self):
        df = DataFrame.from_records(
            [(1, 2.1, "line1"), (2, 1.5, "line2")], columns=["A", "B", "C"], index=["A"]
        )
        assert self.pandasSQL.to_sql(df, "test_to_sql_saves_index") == 2
        ix_cols = self._get_index_columns("test_to_sql_saves_index")
        assert ix_cols == [["A"]]

    def _transaction_test(self):
        with self.pandasSQL.run_transaction() as trans:
            stmt = "CREATE TABLE test_trans (A INT, B TEXT)"
            if isinstance(self.pandasSQL, SQLiteDatabase):
                trans.execute(stmt)
            else:
                from sqlalchemy import text

                stmt = text(stmt)
                trans.execute(stmt)

        class DummyException(Exception):
            pass

        # Make sure when transaction is rolled back, no rows get inserted
        ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')"
        if isinstance(self.pandasSQL, SQLDatabase):
            from sqlalchemy import text

            ins_sql = text(ins_sql)
        try:
            with self.pandasSQL.run_transaction() as trans:
                trans.execute(ins_sql)
                raise DummyException("error")
        except DummyException:
            # ignore raised exception
            pass
        res = self.pandasSQL.read_query("SELECT * FROM test_trans")
        assert len(res) == 0

        # Make sure when transaction is committed, rows do get inserted
        with self.pandasSQL.run_transaction() as trans:
            trans.execute(ins_sql)
        res2 = self.pandasSQL.read_query("SELECT * FROM test_trans")
        assert len(res2) == 1


# -----------------------------------------------------------------------------
# -- Testing the public API


class _TestSQLApi(PandasSQLTest):
    """
    Base class to test the public API.

    From this, two classes are derived to run these tests for both the
    sqlalchemy mode (`TestSQLApi`) and the fallback mode
    (`TestSQLiteFallbackApi`). These tests are run with sqlite3. Specific
    tests for the different sql flavors are included in `_TestSQLAlchemy`.

    Notes:
    flavor can always be passed even in SQLAlchemy mode,
    should be correctly ignored.

    we don't use drop_table because that isn't part of the public api
    """

    flavor = "sqlite"
    mode: str

    @pytest.fixture(autouse=True)
    def setup_method(self, iris_path, types_data):
        self.conn = self.connect()
        self.load_iris_data(iris_path)
        self.load_types_data(types_data)
        self.load_test_data_and_sql()

    def load_test_data_and_sql(self):
        create_and_load_iris_view(self.conn)

    def test_read_sql_view(self):
        iris_frame = sql.read_sql_query("SELECT * FROM iris_view", self.conn)
        check_iris_frame(iris_frame)

    def test_read_sql_with_chunksize_no_result(self):
        query = "SELECT * FROM iris_view WHERE SepalLength < 0.0"
        with_batch = sql.read_sql_query(query, self.conn, chunksize=5)
        without_batch = sql.read_sql_query(query, self.conn)
        tm.assert_frame_equal(concat(with_batch), without_batch)

    def test_to_sql(self, test_frame1):
        sql.to_sql(test_frame1, "test_frame1", self.conn)
        assert sql.has_table("test_frame1", self.conn)

    def test_to_sql_fail(self, test_frame1):
        sql.to_sql(test_frame1, "test_frame2", self.conn, if_exists="fail")
        assert sql.has_table("test_frame2", self.conn)

        msg = "Table 'test_frame2' already exists"
        with pytest.raises(ValueError, match=msg):
            sql.to_sql(test_frame1, "test_frame2", self.conn, if_exists="fail")

    def test_to_sql_replace(self, test_frame1):
        sql.to_sql(test_frame1, "test_frame3", self.conn, if_exists="fail")
        # Add to table again
        sql.to_sql(test_frame1, "test_frame3", self.conn, if_exists="replace")
        assert sql.has_table("test_frame3", self.conn)

        num_entries = len(test_frame1)
        num_rows = count_rows(self.conn, "test_frame3")
        assert num_rows == num_entries

    def test_to_sql_append(self, test_frame1):
        assert sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="fail") == 4

        # Add to table again
        assert (
            sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="append") == 4
        )
        assert sql.has_table("test_frame4", self.conn)

        num_entries = 2 * len(test_frame1)
        num_rows = count_rows(self.conn, "test_frame4")
        assert num_rows == num_entries

    def test_to_sql_type_mapping(self, test_frame3):
        sql.to_sql(test_frame3, "test_frame5", self.conn, index=False)
        result = sql.read_sql("SELECT * FROM test_frame5", self.conn)
        tm.assert_frame_equal(test_frame3, result)

    def test_to_sql_series(self):
        s = Series(np.arange(5, dtype="int64"), name="series")
        sql.to_sql(s, "test_series", self.conn, index=False)
        s2 = sql.read_sql_query("SELECT * FROM test_series", self.conn)
        tm.assert_frame_equal(s.to_frame(), s2)

    def test_roundtrip(self, test_frame1):
        sql.to_sql(test_frame1, "test_frame_roundtrip", con=self.conn)
        result = sql.read_sql_query(
            "SELECT * FROM test_frame_roundtrip", con=self.conn
        )

        # HACK!
        result.index = test_frame1.index
        result.set_index("level_0", inplace=True)
        result.index.astype(int)
        result.index.name = None
        tm.assert_frame_equal(result, test_frame1)

    def test_roundtrip_chunksize(self, test_frame1):
        sql.to_sql(
            test_frame1,
            "test_frame_roundtrip",
            con=self.conn,
            index=False,
            chunksize=2,
        )
        result = sql.read_sql_query(
            "SELECT * FROM test_frame_roundtrip", con=self.conn
        )
        tm.assert_frame_equal(result, test_frame1)

    def test_execute_sql(self):
        # drop_sql = "DROP TABLE IF EXISTS test"  # should already be done
        with sql.pandasSQL_builder(self.conn) as pandas_sql:
            iris_results = pandas_sql.execute("SELECT * FROM iris")
            row = iris_results.fetchone()
            tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])

    def test_date_parsing(self):
        # Test date parsing in read_sql
        # No Parsing
        df = sql.read_sql_query("SELECT * FROM types", self.conn)
        assert not issubclass(df.DateCol.dtype.type, np.datetime64)

        df = sql.read_sql_query(
            "SELECT * FROM types", self.conn, parse_dates=["DateCol"]
        )
        assert issubclass(df.DateCol.dtype.type, np.datetime64)
        assert df.DateCol.tolist() == [
            Timestamp(2000, 1, 3, 0, 0, 0),
            Timestamp(2000, 1, 4, 0, 0, 0),
        ]

        df = sql.read_sql_query(
            "SELECT * FROM types",
            self.conn,
            parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"},
        )
        assert issubclass(df.DateCol.dtype.type, np.datetime64)
        assert df.DateCol.tolist() == [
            Timestamp(2000, 1, 3, 0, 0, 0),
            Timestamp(2000, 1, 4, 0, 0, 0),
        ]

        df = sql.read_sql_query(
            "SELECT * FROM types", self.conn, parse_dates=["IntDateCol"]
        )
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)
        assert df.IntDateCol.tolist() == [
            Timestamp(1986, 12, 25, 0, 0, 0),
            Timestamp(2013, 1, 1, 0, 0, 0),
        ]

        df = sql.read_sql_query(
            "SELECT * FROM types", self.conn, parse_dates={"IntDateCol": "s"}
        )
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)
        assert df.IntDateCol.tolist() == [
            Timestamp(1986, 12, 25, 0, 0, 0),
            Timestamp(2013, 1, 1, 0, 0, 0),
        ]

        df = sql.read_sql_query(
            "SELECT * FROM types",
            self.conn,
            parse_dates={"IntDateOnlyCol": "%Y%m%d"},
        )
        assert issubclass(df.IntDateOnlyCol.dtype.type, np.datetime64)
        assert df.IntDateOnlyCol.tolist() == [
            Timestamp("2010-10-10"),
            Timestamp("2010-12-12"),
        ]

    @pytest.mark.parametrize("error", ["ignore", "raise", "coerce"])
    @pytest.mark.parametrize(
        "read_sql, text, mode",
        [
            (sql.read_sql, "SELECT * FROM types", ("sqlalchemy", "fallback")),
            (sql.read_sql, "types", ("sqlalchemy")),
            (
                sql.read_sql_query,
                "SELECT * FROM types",
                ("sqlalchemy", "fallback"),
            ),
            (sql.read_sql_table, "types", ("sqlalchemy")),
        ],
    )
    def test_custom_dateparsing_error(
        self, read_sql, text, mode, error, types_data_frame
    ):
        if self.mode in mode:
            expected = types_data_frame.astype({"DateCol": "datetime64[ns]"})

            result = read_sql(
                text,
                con=self.conn,
                parse_dates={
                    "DateCol": {"errors": error},
                },
            )

            tm.assert_frame_equal(result, expected)

    def test_date_and_index(self):
        # Test case where same column appears in parse_date and index_col
        df = sql.read_sql_query(
            "SELECT * FROM types",
            self.conn,
            index_col="DateCol",
            parse_dates=["DateCol", "IntDateCol"],
        )
        assert issubclass(df.index.dtype.type, np.datetime64)
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)

    def test_timedelta(self):
        # see #6921
        df = to_timedelta(Series(["00:00:01", "00:00:03"], name="foo")).to_frame()
        with tm.assert_produces_warning(UserWarning):
            result_count = df.to_sql("test_timedelta", self.conn)
        assert result_count == 2
        result = sql.read_sql_query("SELECT * FROM test_timedelta", self.conn)
        tm.assert_series_equal(result["foo"], df["foo"].view("int64"))

    def test_complex_raises(self):
        df = DataFrame({"a": [1 + 1j, 2j]})
        msg = "Complex datatypes not supported"
        with pytest.raises(ValueError, match=msg):
            assert df.to_sql("test_complex", self.conn) is None

    @pytest.mark.parametrize(
        "index_name,index_label,expected",
        [
            # no index name, defaults to 'index'
            (None, None, "index"),
            # specifying index_label
            (None, "other_label", "other_label"),
            # using the index name
            ("index_name", None, "index_name"),
            # has index name, but specifying index_label
            ("index_name", "other_label", "other_label"),
            # index name is integer
            (0, None, "0"),
            # index name is None but index label is integer
            (None, 0, "0"),
        ],
    )
    def test_to_sql_index_label(self, index_name, index_label, expected):
        temp_frame = DataFrame({"col1": range(4)})
        temp_frame.index.name = index_name
        query = "SELECT * FROM test_index_label"
        sql.to_sql(temp_frame, "test_index_label", self.conn, index_label=index_label)
        frame = sql.read_sql_query(query, self.conn)
        assert frame.columns[0] == expected

    def test_to_sql_index_label_multiindex(self):
        expected_row_count = 4
        temp_frame = DataFrame(
            {"col1": range(4)},
            index=MultiIndex.from_product([("A0", "A1"), ("B0", "B1")]),
        )

        # no index name, defaults to 'level_0' and 'level_1'
        result = sql.to_sql(temp_frame, "test_index_label", self.conn)
        assert result == expected_row_count
        frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn)
        assert frame.columns[0] == "level_0"
        assert frame.columns[1] == "level_1"

        # specifying index_label
        result = sql.to_sql(
            temp_frame,
            "test_index_label",
            self.conn,
            if_exists="replace",
            index_label=["A", "B"],
        )
        assert result == expected_row_count
        frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn)
        assert frame.columns[:2].tolist() == ["A", "B"]

        # using the index name
        temp_frame.index.names = ["A", "B"]
        result = sql.to_sql(
            temp_frame, "test_index_label", self.conn, if_exists="replace"
        )
        assert result == expected_row_count
        frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn)
        assert frame.columns[:2].tolist() == ["A", "B"]

        # has index name, but specifying index_label
        result = sql.to_sql(
            temp_frame,
            "test_index_label",
            self.conn,
            if_exists="replace",
            index_label=["C", "D"],
        )
        assert result == expected_row_count
        frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn)
        assert frame.columns[:2].tolist() == ["C", "D"]

        msg = "Length of 'index_label' should match number of levels, which is 2"
        with pytest.raises(ValueError, match=msg):
            sql.to_sql(
                temp_frame,
                "test_index_label",
                self.conn,
                if_exists="replace",
                index_label="C",
            )
  1129. def test_multiindex_roundtrip(self):
  1130. df = DataFrame.from_records(
  1131. [(1, 2.1, "line1"), (2, 1.5, "line2")],
  1132. columns=["A", "B", "C"],
  1133. index=["A", "B"],
  1134. )
  1135. df.to_sql("test_multiindex_roundtrip", self.conn)
  1136. result = sql.read_sql_query(
  1137. "SELECT * FROM test_multiindex_roundtrip", self.conn, index_col=["A", "B"]
  1138. )
  1139. tm.assert_frame_equal(df, result, check_index_type=True)
  1140. @pytest.mark.parametrize(
  1141. "dtype",
  1142. [
  1143. None,
  1144. int,
  1145. float,
  1146. {"A": int, "B": float},
  1147. ],
  1148. )
  1149. def test_dtype_argument(self, dtype):
  1150. # GH10285 Add dtype argument to read_sql_query
  1151. df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"])
  1152. assert df.to_sql("test_dtype_argument", self.conn) == 2
  1153. expected = df.astype(dtype)
  1154. result = sql.read_sql_query(
  1155. "SELECT A, B FROM test_dtype_argument", con=self.conn, dtype=dtype
  1156. )
  1157. tm.assert_frame_equal(result, expected)
  1158. def test_integer_col_names(self):
  1159. df = DataFrame([[1, 2], [3, 4]], columns=[0, 1])
  1160. sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace")
  1161. def test_get_schema(self, test_frame1):
  1162. create_sql = sql.get_schema(test_frame1, "test", con=self.conn)
  1163. assert "CREATE" in create_sql
  1164. def test_get_schema_with_schema(self, test_frame1):
  1165. # GH28486
  1166. create_sql = sql.get_schema(test_frame1, "test", con=self.conn, schema="pypi")
  1167. assert "CREATE TABLE pypi." in create_sql
  1168. def test_get_schema_dtypes(self):
  1169. if self.mode == "sqlalchemy":
  1170. from sqlalchemy import Integer
  1171. dtype = Integer
  1172. else:
  1173. dtype = "INTEGER"
  1174. float_frame = DataFrame({"a": [1.1, 1.2], "b": [2.1, 2.2]})
  1175. create_sql = sql.get_schema(
  1176. float_frame, "test", con=self.conn, dtype={"b": dtype}
  1177. )
  1178. assert "CREATE" in create_sql
  1179. assert "INTEGER" in create_sql
  1180. def test_get_schema_keys(self, test_frame1):
  1181. frame = DataFrame({"Col1": [1.1, 1.2], "Col2": [2.1, 2.2]})
  1182. create_sql = sql.get_schema(frame, "test", con=self.conn, keys="Col1")
  1183. constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")'
  1184. assert constraint_sentence in create_sql
  1185. # multiple columns as key (GH10385)
  1186. create_sql = sql.get_schema(test_frame1, "test", con=self.conn, keys=["A", "B"])
  1187. constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")'
  1188. assert constraint_sentence in create_sql
  1189. def test_chunksize_read(self):
  1190. df = DataFrame(np.random.randn(22, 5), columns=list("abcde"))
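        # 22 rows read with chunksize=5 should arrive as chunks of 5, 5, 5, 5, 2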
        df.to_sql("test_chunksize", self.conn, index=False)

        # reading the query in one time
        res1 = sql.read_sql_query("select * from test_chunksize", self.conn)

        # reading the query in chunks with read_sql_query
        res2 = DataFrame()
        i = 0
        sizes = [5, 5, 5, 5, 2]

        for chunk in sql.read_sql_query(
            "select * from test_chunksize", self.conn, chunksize=5
        ):
            res2 = concat([res2, chunk], ignore_index=True)
            assert len(chunk) == sizes[i]
            i += 1

        tm.assert_frame_equal(res1, res2)

        # reading the query in chunks with read_sql_table
        if self.mode == "sqlalchemy":
            res3 = DataFrame()
            i = 0
            sizes = [5, 5, 5, 5, 2]

            for chunk in sql.read_sql_table("test_chunksize", self.conn, chunksize=5):
                res3 = concat([res3, chunk], ignore_index=True)
                assert len(chunk) == sizes[i]
                i += 1

            tm.assert_frame_equal(res1, res3)

    def test_categorical(self):
        # GH8624
        # test that categorical gets written correctly as dense column
        df = DataFrame(
            {
                "person_id": [1, 2, 3],
                "person_name": ["John P. Doe", "Jane Dove", "John P. Doe"],
            }
        )
        df2 = df.copy()
        df2["person_name"] = df2["person_name"].astype("category")
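        # the categorical column is stored as its dense values, so reading it
        # back should compare equal to the original, non-categorical frame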
        df2.to_sql("test_categorical", self.conn, index=False)
        res = sql.read_sql_query("SELECT * FROM test_categorical", self.conn)

        tm.assert_frame_equal(res, df)

    def test_unicode_column_name(self):
        # GH 11431
        df = DataFrame([[1, 2], [3, 4]], columns=["\xe9", "b"])
        df.to_sql("test_unicode", self.conn, index=False)

    def test_escaped_table_name(self):
        # GH 13206
        df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]})
        df.to_sql("d1187b08-4943-4c8d-a7f6", self.conn, index=False)
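        # the hyphenated table name must be backtick-quoted in hand-written SQL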
        res = sql.read_sql_query("SELECT * FROM `d1187b08-4943-4c8d-a7f6`", self.conn)

        tm.assert_frame_equal(res, df)

    def test_read_sql_duplicate_columns(self):
        # GH#53117
        df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": 1})
        df.to_sql("test_table", self.conn, index=False)

        result = pd.read_sql("SELECT a, b, a +1 as a, c FROM test_table;", self.conn)
        expected = DataFrame(
            [[1, 0.1, 2, 1], [2, 0.2, 3, 1], [3, 0.3, 4, 1]],
            columns=["a", "b", "a", "c"],
        )
        tm.assert_frame_equal(result, expected)


@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed")
class TestSQLApi(SQLAlchemyMixIn, _TestSQLApi):
    """
    Test the public API as it would be used directly

    Tests for `read_sql_table` are included here, as this is specific for the
    sqlalchemy mode.
    """

    flavor = "sqlite"
    mode = "sqlalchemy"

    @classmethod
    def setup_class(cls):
        cls.engine = sqlalchemy.create_engine("sqlite:///:memory:")

    def test_read_table_columns(self, test_frame1):
        # test columns argument in read_table
        sql.to_sql(test_frame1, "test_frame", self.conn)

        cols = ["A", "B"]
        result = sql.read_sql_table("test_frame", self.conn, columns=cols)
        assert result.columns.tolist() == cols

    def test_read_table_index_col(self, test_frame1):
        # test index_col argument in read_table
        sql.to_sql(test_frame1, "test_frame", self.conn)

        result = sql.read_sql_table("test_frame", self.conn, index_col="index")
        assert result.index.names == ["index"]

        result = sql.read_sql_table("test_frame", self.conn, index_col=["A", "B"])
        assert result.index.names == ["A", "B"]

        result = sql.read_sql_table(
            "test_frame", self.conn, index_col=["A", "B"], columns=["C", "D"]
        )
        assert result.index.names == ["A", "B"]
        assert result.columns.tolist() == ["C", "D"]

    def test_read_sql_delegate(self):
        iris_frame1 = sql.read_sql_query("SELECT * FROM iris", self.conn)
        iris_frame2 = sql.read_sql("SELECT * FROM iris", self.conn)
        tm.assert_frame_equal(iris_frame1, iris_frame2)

        iris_frame1 = sql.read_sql_table("iris", self.conn)
        iris_frame2 = sql.read_sql("iris", self.conn)
        tm.assert_frame_equal(iris_frame1, iris_frame2)

    def test_not_reflect_all_tables(self):
        from sqlalchemy import text
        from sqlalchemy.engine import Engine

        # create invalid table
        query_list = [
            text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"),
            text("CREATE TABLE other_table (x INTEGER, y INTEGER);"),
        ]
        for query in query_list:
            if isinstance(self.conn, Engine):
                with self.conn.connect() as conn:
                    with conn.begin():
                        conn.execute(query)
            else:
                with self.conn.begin():
                    self.conn.execute(query)

        with tm.assert_produces_warning(None):
            sql.read_sql_table("other_table", self.conn)
            sql.read_sql_query("SELECT * FROM other_table", self.conn)

    def test_warning_case_insensitive_table_name(self, test_frame1):
        # see gh-7815
        with tm.assert_produces_warning(
            UserWarning,
            match=(
                r"The provided table name 'TABLE1' is not found exactly as such in "
                r"the database after writing the table, possibly due to case "
                r"sensitivity issues. Consider using lower case table names."
            ),
        ):
            sql.SQLDatabase(self.conn).check_case_sensitive("TABLE1", "")

        # Test that the warning is certainly NOT triggered in a normal case.
        with tm.assert_produces_warning(None):
            test_frame1.to_sql("CaseSensitive", self.conn)

    def _get_index_columns(self, tbl_name):
        from sqlalchemy.engine import reflection

        insp = reflection.Inspector.from_engine(self.conn)
        ixs = insp.get_indexes("test_index_saved")
        ixs = [i["column_names"] for i in ixs]
        return ixs

    def test_sqlalchemy_type_mapping(self):
        from sqlalchemy import TIMESTAMP

        # Test Timestamp objects (no datetime64 because of timezone) (GH9085)
        df = DataFrame(
            {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)}
        )
        db = sql.SQLDatabase(self.conn)
        table = sql.SQLTable("test_type", db, frame=df)
        # GH 9086: TIMESTAMP is the suggested type for datetimes with timezones
        assert isinstance(table.table.c["time"].type, TIMESTAMP)

    @pytest.mark.parametrize(
        "integer, expected",
        [
            ("int8", "SMALLINT"),
            ("Int8", "SMALLINT"),
            ("uint8", "SMALLINT"),
            ("UInt8", "SMALLINT"),
            ("int16", "SMALLINT"),
            ("Int16", "SMALLINT"),
            ("uint16", "INTEGER"),
            ("UInt16", "INTEGER"),
            ("int32", "INTEGER"),
            ("Int32", "INTEGER"),
            ("uint32", "BIGINT"),
            ("UInt32", "BIGINT"),
            ("int64", "BIGINT"),
            ("Int64", "BIGINT"),
            (int, "BIGINT" if np.dtype(int).name == "int64" else "INTEGER"),
        ],
    )
    def test_sqlalchemy_integer_mapping(self, integer, expected):
        # GH35076 Map pandas integer to optimal SQLAlchemy integer type
        df = DataFrame([0, 1], columns=["a"], dtype=integer)
        db = sql.SQLDatabase(self.conn)
        table = sql.SQLTable("test_type", db, frame=df)

        result = str(table.table.c.a.type)
        assert result == expected

    @pytest.mark.parametrize("integer", ["uint64", "UInt64"])
    def test_sqlalchemy_integer_overload_mapping(self, integer):
        # GH35076 Map pandas integer to optimal SQLAlchemy integer type
        df = DataFrame([0, 1], columns=["a"], dtype=integer)
        db = sql.SQLDatabase(self.conn)
        with pytest.raises(
            ValueError, match="Unsigned 64 bit integer datatype is not supported"
        ):
            sql.SQLTable("test_type", db, frame=df)

    def test_database_uri_string(self, test_frame1):
        # Test read_sql and .to_sql method with a database URI (GH10654)
        # db_uri = 'sqlite:///:memory:' # raises
        # sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near
        # "iris": syntax error [SQL: 'iris']
        with tm.ensure_clean() as name:
            db_uri = "sqlite:///" + name
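            # passing a URI string should make pandas build a SQLAlchemy
            # engine internally rather than requiring a connection object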
            table = "iris"
            test_frame1.to_sql(table, db_uri, if_exists="replace", index=False)
            test_frame2 = sql.read_sql(table, db_uri)
            test_frame3 = sql.read_sql_table(table, db_uri)
            query = "SELECT * FROM iris"
            test_frame4 = sql.read_sql_query(query, db_uri)
        tm.assert_frame_equal(test_frame1, test_frame2)
        tm.assert_frame_equal(test_frame1, test_frame3)
        tm.assert_frame_equal(test_frame1, test_frame4)

    @td.skip_if_installed("pg8000")
    def test_pg8000_sqlalchemy_passthrough_error(self):
        # using driver that will not be installed on CI to trigger error
        # in sqlalchemy.create_engine -> test passing of this error to user
        db_uri = "postgresql+pg8000://user:pass@host/dbname"
        with pytest.raises(ImportError, match="pg8000"):
            sql.read_sql("select * from table", db_uri)

    def test_query_by_text_obj(self):
        # WIP : GH10846
        from sqlalchemy import text

        name_text = text("select * from iris where name=:name")
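        # `:name` is a bound parameter; read_sql fills it from the params dict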
        iris_df = sql.read_sql(name_text, self.conn, params={"name": "Iris-versicolor"})
        all_names = set(iris_df["Name"])
        assert all_names == {"Iris-versicolor"}

    def test_query_by_select_obj(self):
        # WIP : GH10846
        from sqlalchemy import (
            bindparam,
            select,
        )

        iris = iris_table_metadata(self.flavor)
        name_select = select(iris).where(iris.c.Name == bindparam("name"))
        iris_df = sql.read_sql(name_select, self.conn, params={"name": "Iris-setosa"})
        all_names = set(iris_df["Name"])
        assert all_names == {"Iris-setosa"}

    def test_column_with_percentage(self):
        # GH 37157
        df = DataFrame({"A": [0, 1, 2], "%_variation": [3, 4, 5]})
        df.to_sql("test_column_percentage", self.conn, index=False)

        res = sql.read_sql_table("test_column_percentage", self.conn)

        tm.assert_frame_equal(res, df)


class TestSQLiteFallbackApi(SQLiteMixIn, _TestSQLApi):
    """
    Test the public sqlite connection fallback API
    """

    flavor = "sqlite"
    mode = "fallback"

    def connect(self, database=":memory:"):
        return sqlite3.connect(database)

    def test_sql_open_close(self, test_frame3):
        # Test that the IO in the database still works if the connection is
        # closed between the writing and reading (as in many real situations).

        with tm.ensure_clean() as name:
            with closing(self.connect(name)) as conn:
                assert (
                    sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False)
                    == 4
                )

            with closing(self.connect(name)) as conn:
                result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn)

        tm.assert_frame_equal(test_frame3, result)

    @pytest.mark.skipif(SQLALCHEMY_INSTALLED, reason="SQLAlchemy is installed")
    def test_con_string_import_error(self):
        conn = "mysql://root@localhost/pandas"
        msg = "Using URI string without sqlalchemy installed"
        with pytest.raises(ImportError, match=msg):
            sql.read_sql("SELECT * FROM iris", conn)

    @pytest.mark.skipif(SQLALCHEMY_INSTALLED, reason="SQLAlchemy is installed")
    def test_con_unknown_dbapi2_class_does_not_error_without_sql_alchemy_installed(
        self,
    ):
        class MockSqliteConnection:
            def __init__(self, *args, **kwargs) -> None:
                self.conn = sqlite3.Connection(*args, **kwargs)

            def __getattr__(self, name):
                return getattr(self.conn, name)

            def close(self):
                self.conn.close()

        with contextlib.closing(MockSqliteConnection(":memory:")) as conn:
            with tm.assert_produces_warning(UserWarning):
                sql.read_sql("SELECT 1", conn)

    def test_read_sql_delegate(self):
        iris_frame1 = sql.read_sql_query("SELECT * FROM iris", self.conn)
        iris_frame2 = sql.read_sql("SELECT * FROM iris", self.conn)
        tm.assert_frame_equal(iris_frame1, iris_frame2)

        msg = "Execution failed on sql 'iris': near \"iris\": syntax error"
        with pytest.raises(sql.DatabaseError, match=msg):
            sql.read_sql("iris", self.conn)

    def test_get_schema2(self, test_frame1):
        # without providing a connection object (available for backwards
        # compatibility)
        create_sql = sql.get_schema(test_frame1, "test")
        assert "CREATE" in create_sql

    def _get_sqlite_column_type(self, schema, column):
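        # naive parse of the CREATE TABLE statement: on each line the first
        # token is the column name, the second its declared type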
        for col in schema.split("\n"):
            if col.split()[0].strip('"') == column:
                return col.split()[1]
        raise ValueError(f"Column {column} not found")

    def test_sqlite_type_mapping(self):
        # Test Timestamp objects (no datetime64 because of timezone) (GH9085)
        df = DataFrame(
            {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)}
        )
        db = sql.SQLiteDatabase(self.conn)
        table = sql.SQLiteTable("test_type", db, frame=df)
        schema = table.sql_schema()
        assert self._get_sqlite_column_type(schema, "time") == "TIMESTAMP"


# -----------------------------------------------------------------------------
# -- Database flavor specific tests


@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed")
class _TestSQLAlchemy(SQLAlchemyMixIn, PandasSQLTest):
    """
    Base class for testing the sqlalchemy backend.

    Subclasses for specific database types are created below. Tests that
    deviate for each flavor are overwritten there.
    """

    flavor: str

    @classmethod
    def setup_class(cls):
        cls.setup_driver()
        cls.setup_engine()

    @pytest.fixture(autouse=True)
    def setup_method(self, iris_path, types_data):
        try:
            self.conn = self.engine.connect()
            self.pandasSQL = sql.SQLDatabase(self.conn)
        except sqlalchemy.exc.OperationalError:
            pytest.skip(f"Can't connect to {self.flavor} server")
        self.load_iris_data(iris_path)
        self.load_types_data(types_data)

    @classmethod
    def setup_driver(cls):
        raise NotImplementedError()

    @classmethod
    def setup_engine(cls):
        raise NotImplementedError()

    def test_read_sql_parameter(self):
        self._read_sql_iris_parameter()

    def test_read_sql_named_parameter(self):
        self._read_sql_iris_named_parameter()

    def test_to_sql_empty(self, test_frame1):
        self._to_sql_empty(test_frame1)

    def test_create_table(self):
        from sqlalchemy import inspect

        temp_conn = self.connect()
        temp_frame = DataFrame(
            {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
        )
        with sql.SQLDatabase(temp_conn, need_transaction=True) as pandasSQL:
            assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4

        insp = inspect(temp_conn)
        assert insp.has_table("temp_frame")

    def test_drop_table(self):
        from sqlalchemy import inspect

        temp_conn = self.connect()
        temp_frame = DataFrame(
            {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
        )
        pandasSQL = sql.SQLDatabase(temp_conn)
        assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4

        insp = inspect(temp_conn)
        assert insp.has_table("temp_frame")

        pandasSQL.drop_table("temp_frame")
        try:
            insp.clear_cache()  # needed with SQLAlchemy 2.0, unavailable prior
        except AttributeError:
            pass
        assert not insp.has_table("temp_frame")

    def test_roundtrip(self, test_frame1):
        self._roundtrip(test_frame1)

    def test_execute_sql(self):
        self._execute_sql()

    def test_read_table(self):
        iris_frame = sql.read_sql_table("iris", con=self.conn)
        check_iris_frame(iris_frame)

    def test_read_table_columns(self):
        iris_frame = sql.read_sql_table(
            "iris", con=self.conn, columns=["SepalLength", "SepalLength"]
        )
        tm.equalContents(iris_frame.columns.values, ["SepalLength", "SepalLength"])

    def test_read_table_absent_raises(self):
        msg = "Table this_doesnt_exist not found"
        with pytest.raises(ValueError, match=msg):
            sql.read_sql_table("this_doesnt_exist", con=self.conn)

    def test_default_type_conversion(self):
        df = sql.read_sql_table("types", self.conn)

        assert issubclass(df.FloatCol.dtype.type, np.floating)
        assert issubclass(df.IntCol.dtype.type, np.integer)
        assert issubclass(df.BoolCol.dtype.type, np.bool_)

        # Int column with NA values stays as float
        assert issubclass(df.IntColWithNull.dtype.type, np.floating)
        # Bool column with NA values becomes object
        assert issubclass(df.BoolColWithNull.dtype.type, object)

    def test_bigint(self):
        # int64 should be converted to BigInteger, GH7433
        df = DataFrame(data={"i64": [2**62]})
        assert df.to_sql("test_bigint", self.conn, index=False) == 1
        result = sql.read_sql_table("test_bigint", self.conn)

        tm.assert_frame_equal(df, result)

    def test_default_date_load(self):
        df = sql.read_sql_table("types", self.conn)

        # IMPORTANT - sqlite has no native date type, so shouldn't parse, but
        # MySQL SHOULD be converted.
        assert issubclass(df.DateCol.dtype.type, np.datetime64)

    def test_datetime_with_timezone(self, request):
        # edge case that converts postgresql datetime with time zone types
        # to datetime64[ns,psycopg2.tz.FixedOffsetTimezone..], which is ok
        # but should be more natural, so coerce to datetime64[ns] for now

        def check(col):
            # check that a column is either datetime64[ns]
            # or datetime64[ns, UTC]
            if is_datetime64_dtype(col.dtype):
                # "2000-01-01 00:00:00-08:00" should convert to
                # "2000-01-01 08:00:00"
                assert col[0] == Timestamp("2000-01-01 08:00:00")

                # "2000-06-01 00:00:00-07:00" should convert to
                # "2000-06-01 07:00:00"
                assert col[1] == Timestamp("2000-06-01 07:00:00")

            elif is_datetime64tz_dtype(col.dtype):
                assert str(col.dt.tz) == "UTC"

                # "2000-01-01 00:00:00-08:00" should convert to
                # "2000-01-01 08:00:00"
                # "2000-06-01 00:00:00-07:00" should convert to
                # "2000-06-01 07:00:00"
                # GH 6415
                expected_data = [
                    Timestamp("2000-01-01 08:00:00", tz="UTC"),
                    Timestamp("2000-06-01 07:00:00", tz="UTC"),
                ]
                expected = Series(expected_data, name=col.name)
                tm.assert_series_equal(col, expected)

            else:
                raise AssertionError(
                    f"DateCol loaded with incorrect type -> {col.dtype}"
                )

        # GH11216
        df = read_sql_query("select * from types", self.conn)
        if not hasattr(df, "DateColWithTz"):
            request.node.add_marker(
                pytest.mark.xfail(reason="no column with datetime with time zone")
            )

        # this is parsed on Travis (linux), but not on macosx for some reason
        # even with the same versions of psycopg2 & sqlalchemy, possibly a
        # Postgresql server version difference
        col = df.DateColWithTz
        assert is_datetime64tz_dtype(col.dtype)

        df = read_sql_query(
            "select * from types", self.conn, parse_dates=["DateColWithTz"]
        )
        if not hasattr(df, "DateColWithTz"):
            request.node.add_marker(
                pytest.mark.xfail(reason="no column with datetime with time zone")
            )
        col = df.DateColWithTz
        assert is_datetime64tz_dtype(col.dtype)
        assert str(col.dt.tz) == "UTC"
        check(df.DateColWithTz)

        df = concat(
            list(read_sql_query("select * from types", self.conn, chunksize=1)),
            ignore_index=True,
        )
        col = df.DateColWithTz
        assert is_datetime64tz_dtype(col.dtype)
        assert str(col.dt.tz) == "UTC"
        expected = sql.read_sql_table("types", self.conn)
        col = expected.DateColWithTz
        assert is_datetime64tz_dtype(col.dtype)
        tm.assert_series_equal(df.DateColWithTz, expected.DateColWithTz)

        # xref #7139
        # this might or might not be converted depending on the postgres driver
        df = sql.read_sql_table("types", self.conn)
        check(df.DateColWithTz)

    def test_datetime_with_timezone_roundtrip(self):
        # GH 9086
        # Write datetimetz data to a db and read it back
        # For dbs that support timestamps with timezones, should get back UTC
        # otherwise naive data should be returned
        expected = DataFrame(
            {"A": date_range("2013-01-01 09:00:00", periods=3, tz="US/Pacific")}
        )
        assert expected.to_sql("test_datetime_tz", self.conn, index=False) == 3

        if self.flavor == "postgresql":
            # SQLAlchemy "timezones" (i.e. offsets) are coerced to UTC
            expected["A"] = expected["A"].dt.tz_convert("UTC")
        else:
            # Otherwise, timestamps are returned as local, naive
            expected["A"] = expected["A"].dt.tz_localize(None)

        result = sql.read_sql_table("test_datetime_tz", self.conn)
        tm.assert_frame_equal(result, expected)

        result = sql.read_sql_query("SELECT * FROM test_datetime_tz", self.conn)
        if self.flavor == "sqlite":
            # read_sql_query does not return datetime type like read_sql_table
            assert isinstance(result.loc[0, "A"], str)
            result["A"] = to_datetime(result["A"])
        tm.assert_frame_equal(result, expected)

    def test_out_of_bounds_datetime(self):
        # GH 26761
        data = DataFrame({"date": datetime(9999, 1, 1)}, index=[0])
        assert data.to_sql("test_datetime_obb", self.conn, index=False) == 1
        result = sql.read_sql_table("test_datetime_obb", self.conn)
        expected = DataFrame([pd.NaT], columns=["date"])
        tm.assert_frame_equal(result, expected)

    def test_naive_datetimeindex_roundtrip(self):
        # GH 23510
        # Ensure that a naive DatetimeIndex isn't converted to UTC
        dates = date_range("2018-01-01", periods=5, freq="6H")._with_freq(None)
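        # the freq is dropped up front: an index read back from SQL is rebuilt
        # from values, so a freq attribute would not survive the round trip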
        expected = DataFrame({"nums": range(5)}, index=dates)
        assert expected.to_sql("foo_table", self.conn, index_label="info_date") == 5
        result = sql.read_sql_table("foo_table", self.conn, index_col="info_date")
        # the result index will gain a name from the set_index operation;
        # expected has none, so ignore names in the comparison
        tm.assert_frame_equal(result, expected, check_names=False)

    def test_date_parsing(self):
        # No Parsing
        df = sql.read_sql_table("types", self.conn)
        expected_type = object if self.flavor == "sqlite" else np.datetime64
        assert issubclass(df.DateCol.dtype.type, expected_type)

        df = sql.read_sql_table("types", self.conn, parse_dates=["DateCol"])
        assert issubclass(df.DateCol.dtype.type, np.datetime64)

        df = sql.read_sql_table(
            "types", self.conn, parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"}
        )
        assert issubclass(df.DateCol.dtype.type, np.datetime64)

        df = sql.read_sql_table(
            "types",
            self.conn,
            parse_dates={"DateCol": {"format": "%Y-%m-%d %H:%M:%S"}},
        )
        assert issubclass(df.DateCol.dtype.type, np.datetime64)

        df = sql.read_sql_table("types", self.conn, parse_dates=["IntDateCol"])
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)

        df = sql.read_sql_table("types", self.conn, parse_dates={"IntDateCol": "s"})
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)

        df = sql.read_sql_table(
            "types", self.conn, parse_dates={"IntDateCol": {"unit": "s"}}
        )
        assert issubclass(df.IntDateCol.dtype.type, np.datetime64)

    def test_datetime(self):
        df = DataFrame(
            {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)}
        )
        assert df.to_sql("test_datetime", self.conn) == 3

        # with read_table -> type information from schema used
        result = sql.read_sql_table("test_datetime", self.conn)
        result = result.drop("index", axis=1)
        tm.assert_frame_equal(result, df)

        # with read_sql -> no type information -> sqlite has no native
        # datetime type, so values come back as strings
        result = sql.read_sql_query("SELECT * FROM test_datetime", self.conn)
        result = result.drop("index", axis=1)
        if self.flavor == "sqlite":
            assert isinstance(result.loc[0, "A"], str)
            result["A"] = to_datetime(result["A"])
            tm.assert_frame_equal(result, df)
        else:
            tm.assert_frame_equal(result, df)

    def test_datetime_NaT(self):
        df = DataFrame(
            {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)}
        )
        df.loc[1, "A"] = np.nan
        assert df.to_sql("test_datetime", self.conn, index=False) == 3

        # with read_table -> type information from schema used
        result = sql.read_sql_table("test_datetime", self.conn)
        tm.assert_frame_equal(result, df)

        # with read_sql -> no type information -> sqlite has no native
        # datetime type, so values come back as strings
        result = sql.read_sql_query("SELECT * FROM test_datetime", self.conn)
        if self.flavor == "sqlite":
            assert isinstance(result.loc[0, "A"], str)
            result["A"] = to_datetime(result["A"], errors="coerce")
            tm.assert_frame_equal(result, df)
        else:
            tm.assert_frame_equal(result, df)

    def test_datetime_date(self):
        # test support for datetime.date
        df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"])
        assert df.to_sql("test_date", self.conn, index=False) == 2
        res = read_sql_table("test_date", self.conn)
        result = res["a"]
        expected = to_datetime(df["a"])
        # comes back as datetime64
        tm.assert_series_equal(result, expected)

    def test_datetime_time(self, sqlite_buildin):
        # test support for datetime.time
        df = DataFrame([time(9, 0, 0), time(9, 1, 30)], columns=["a"])
        assert df.to_sql("test_time", self.conn, index=False) == 2
        res = read_sql_table("test_time", self.conn)
        tm.assert_frame_equal(res, df)

        # GH8341
        # first, use the fallback to have the sqlite adapter put in place
        sqlite_conn = sqlite_buildin
        assert sql.to_sql(df, "test_time2", sqlite_conn, index=False) == 2
        res = sql.read_sql_query("SELECT * FROM test_time2", sqlite_conn)
        ref = df.applymap(lambda _: _.strftime("%H:%M:%S.%f"))
        tm.assert_frame_equal(ref, res)  # check if adapter is in place

        # then test if sqlalchemy is unaffected by the sqlite adapter
        assert sql.to_sql(df, "test_time3", self.conn, index=False) == 2
        if self.flavor == "sqlite":
            res = sql.read_sql_query("SELECT * FROM test_time3", self.conn)
            ref = df.applymap(lambda _: _.strftime("%H:%M:%S.%f"))
            tm.assert_frame_equal(ref, res)
        res = sql.read_sql_table("test_time3", self.conn)
        tm.assert_frame_equal(df, res)

    def test_mixed_dtype_insert(self):
        # see GH6509
        s1 = Series(2**25 + 1, dtype=np.int32)
        s2 = Series(0.0, dtype=np.float32)
        df = DataFrame({"s1": s1, "s2": s2})

        # write and read again
        assert df.to_sql("test_read_write", self.conn, index=False) == 1
        df2 = sql.read_sql_table("test_read_write", self.conn)

        tm.assert_frame_equal(df, df2, check_dtype=False, check_exact=True)

    def test_nan_numeric(self):
        # NaNs in numeric float column
        df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]})
        assert df.to_sql("test_nan", self.conn, index=False) == 3

        # with read_table
        result = sql.read_sql_table("test_nan", self.conn)
        tm.assert_frame_equal(result, df)

        # with read_sql
        result = sql.read_sql_query("SELECT * FROM test_nan", self.conn)
        tm.assert_frame_equal(result, df)

    def test_nan_fullcolumn(self):
        # full NaN column (numeric float column)
        df = DataFrame({"A": [0, 1, 2], "B": [np.nan, np.nan, np.nan]})
        assert df.to_sql("test_nan", self.conn, index=False) == 3

        # with read_table
        result = sql.read_sql_table("test_nan", self.conn)
        tm.assert_frame_equal(result, df)

        # with read_sql -> no type info from table -> stays None
        df["B"] = df["B"].astype("object")
        df["B"] = None
        result = sql.read_sql_query("SELECT * FROM test_nan", self.conn)
        tm.assert_frame_equal(result, df)

    def test_nan_string(self):
        # NaNs in string column
        df = DataFrame({"A": [0, 1, 2], "B": ["a", "b", np.nan]})
        assert df.to_sql("test_nan", self.conn, index=False) == 3

        # NaNs are coming back as None
        df.loc[2, "B"] = None

        # with read_table
        result = sql.read_sql_table("test_nan", self.conn)
        tm.assert_frame_equal(result, df)

        # with read_sql
        result = sql.read_sql_query("SELECT * FROM test_nan", self.conn)
        tm.assert_frame_equal(result, df)

    def _get_index_columns(self, tbl_name):
        from sqlalchemy import inspect

        insp = inspect(self.conn)
        ixs = insp.get_indexes(tbl_name)
        ixs = [i["column_names"] for i in ixs]
        return ixs

    def test_to_sql_save_index(self):
        self._to_sql_save_index()

    def test_transactions(self):
        self._transaction_test()

    def test_get_schema_create_table(self, test_frame3):
        # Use a dataframe without a bool column, since MySQL converts bool to
        # TINYINT (which read_sql_table returns as an int and causes a dtype
        # mismatch)
        from sqlalchemy import text
        from sqlalchemy.engine import Engine

        tbl = "test_get_schema_create_table"
        create_sql = sql.get_schema(test_frame3, tbl, con=self.conn)
        blank_test_df = test_frame3.iloc[:0]

        self.drop_table(tbl, self.conn)
        create_sql = text(create_sql)
        if isinstance(self.conn, Engine):
            with self.conn.connect() as conn:
                with conn.begin():
                    conn.execute(create_sql)
        else:
            with self.conn.begin():
                self.conn.execute(create_sql)
        returned_df = sql.read_sql_table(tbl, self.conn)
        tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False)
        self.drop_table(tbl, self.conn)

    def test_dtype(self):
        from sqlalchemy import (
            TEXT,
            String,
        )
        from sqlalchemy.schema import MetaData

        cols = ["A", "B"]
        data = [(0.8, True), (0.9, None)]
        df = DataFrame(data, columns=cols)
        assert df.to_sql("dtype_test", self.conn) == 2
        assert df.to_sql("dtype_test2", self.conn, dtype={"B": TEXT}) == 2
        meta = MetaData()
        meta.reflect(bind=self.conn)
        sqltype = meta.tables["dtype_test2"].columns["B"].type
        assert isinstance(sqltype, TEXT)
        msg = "The type of B is not a SQLAlchemy type"
        with pytest.raises(ValueError, match=msg):
            df.to_sql("error", self.conn, dtype={"B": str})

        # GH9083
        assert df.to_sql("dtype_test3", self.conn, dtype={"B": String(10)}) == 2
        meta.reflect(bind=self.conn)
        sqltype = meta.tables["dtype_test3"].columns["B"].type
        assert isinstance(sqltype, String)
        assert sqltype.length == 10

        # single dtype
        assert df.to_sql("single_dtype_test", self.conn, dtype=TEXT) == 2
        meta.reflect(bind=self.conn)
        sqltypea = meta.tables["single_dtype_test"].columns["A"].type
        sqltypeb = meta.tables["single_dtype_test"].columns["B"].type
        assert isinstance(sqltypea, TEXT)
        assert isinstance(sqltypeb, TEXT)

    def test_notna_dtype(self):
        from sqlalchemy import (
            Boolean,
            DateTime,
            Float,
            Integer,
        )
        from sqlalchemy.schema import MetaData

        cols = {
            "Bool": Series([True, None]),
            "Date": Series([datetime(2012, 5, 1), None]),
            "Int": Series([1, None], dtype="object"),
            "Float": Series([1.1, None]),
        }
        df = DataFrame(cols)

        tbl = "notna_dtype_test"
        assert df.to_sql(tbl, self.conn) == 2
        _ = sql.read_sql_table(tbl, self.conn)
        meta = MetaData()
        meta.reflect(bind=self.conn)
        my_type = Integer if self.flavor == "mysql" else Boolean
        col_dict = meta.tables[tbl].columns
        assert isinstance(col_dict["Bool"].type, my_type)
        assert isinstance(col_dict["Date"].type, DateTime)
        assert isinstance(col_dict["Int"].type, Integer)
        assert isinstance(col_dict["Float"].type, Float)

    def test_double_precision(self):
        from sqlalchemy import (
            BigInteger,
            Float,
            Integer,
        )
        from sqlalchemy.schema import MetaData

        V = 1.23456789101112131415
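        # float32 keeps only ~7 significant digits, so V exposes whether a
        # column was stored with 32- or 64-bit precision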
        df = DataFrame(
            {
                "f32": Series([V], dtype="float32"),
                "f64": Series([V], dtype="float64"),
                "f64_as_f32": Series([V], dtype="float64"),
                "i32": Series([5], dtype="int32"),
                "i64": Series([5], dtype="int64"),
            }
        )

        assert (
            df.to_sql(
                "test_dtypes",
                self.conn,
                index=False,
                if_exists="replace",
                dtype={"f64_as_f32": Float(precision=23)},
            )
            == 1
        )
        res = sql.read_sql_table("test_dtypes", self.conn)

        # check precision of float64
        assert np.round(df["f64"].iloc[0], 14) == np.round(res["f64"].iloc[0], 14)

        # check sql types
        meta = MetaData()
        meta.reflect(bind=self.conn)
        col_dict = meta.tables["test_dtypes"].columns
        assert str(col_dict["f32"].type) == str(col_dict["f64_as_f32"].type)
        assert isinstance(col_dict["f32"].type, Float)
        assert isinstance(col_dict["f64"].type, Float)
        assert isinstance(col_dict["i32"].type, Integer)
        assert isinstance(col_dict["i64"].type, BigInteger)

    def test_connectable_issue_example(self):
        # This tests the example raised in issue
        # https://github.com/pandas-dev/pandas/issues/10104
        from sqlalchemy.engine import Engine

        def test_select(connection):
            query = "SELECT test_foo_data FROM test_foo_data"
            return sql.read_sql_query(query, con=connection)

        def test_append(connection, data):
            data.to_sql(name="test_foo_data", con=connection, if_exists="append")

        def test_connectable(conn):
            # https://github.com/sqlalchemy/sqlalchemy/commit/
            # 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973
            foo_data = test_select(conn)
            test_append(conn, foo_data)

        def main(connectable):
            if isinstance(connectable, Engine):
                with connectable.connect() as conn:
                    with conn.begin():
                        test_connectable(conn)
            else:
                test_connectable(connectable)

        assert (
            DataFrame({"test_foo_data": [0, 1, 2]}).to_sql("test_foo_data", self.conn)
            == 3
        )
        main(self.conn)

    @pytest.mark.parametrize(
        "input",
        [{"foo": [np.inf]}, {"foo": [-np.inf]}, {"foo": [-np.inf], "infe0": ["bar"]}],
    )
    def test_to_sql_with_negative_npinf(self, input, request):
        # GH 34431

        df = DataFrame(input)

        if self.flavor == "mysql":
            # GH 36465
            # The input {"foo": [-np.inf], "infe0": ["bar"]} does not raise any error
            # for pymysql version >= 0.10
            # TODO(GH#36465): remove this version check after GH 36465 is fixed
            pymysql = pytest.importorskip("pymysql")

            if (
                Version(pymysql.__version__) < Version("1.0.3")
                and "infe0" in df.columns
            ):
                mark = pytest.mark.xfail(reason="GH 36465")
                request.node.add_marker(mark)

            msg = "inf cannot be used with MySQL"
            with pytest.raises(ValueError, match=msg):
                df.to_sql("foobar", self.conn, index=False)
        else:
            assert df.to_sql("foobar", self.conn, index=False) == 1
            res = sql.read_sql_table("foobar", self.conn)
            tm.assert_equal(df, res)

    def test_temporary_table(self):
        from sqlalchemy import (
            Column,
            Integer,
            Unicode,
            select,
        )
        from sqlalchemy.orm import (
            Session,
            declarative_base,
        )

        test_data = "Hello, World!"
        expected = DataFrame({"spam": [test_data]})
        Base = declarative_base()

        class Temporary(Base):
            __tablename__ = "temp_test"
            __table_args__ = {"prefixes": ["TEMPORARY"]}
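            # the TEMPORARY prefix makes SQLAlchemy emit CREATE TEMPORARY
            # TABLE, so the table only exists for the current session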
            id = Column(Integer, primary_key=True)
            spam = Column(Unicode(30), nullable=False)

        with Session(self.conn) as session:
            with session.begin():
                conn = session.connection()
                Temporary.__table__.create(conn)
                session.add(Temporary(spam=test_data))
                session.flush()
                df = sql.read_sql_query(sql=select(Temporary.spam), con=conn)
        tm.assert_frame_equal(df, expected)

    # -- SQL Engine tests (in the base class for now)
    def test_invalid_engine(self, test_frame1):
        msg = "engine must be one of 'auto', 'sqlalchemy'"
        with pytest.raises(ValueError, match=msg):
            self._to_sql_with_sql_engine(test_frame1, "bad_engine")

    def test_options_sqlalchemy(self, test_frame1):
        # use the set option
        with pd.option_context("io.sql.engine", "sqlalchemy"):
            self._to_sql_with_sql_engine(test_frame1)

    def test_options_auto(self, test_frame1):
        # use the set option
        with pd.option_context("io.sql.engine", "auto"):
            self._to_sql_with_sql_engine(test_frame1)

    def test_options_get_engine(self):
        assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine)

        with pd.option_context("io.sql.engine", "sqlalchemy"):
            assert isinstance(get_engine("auto"), SQLAlchemyEngine)
            assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine)

        with pd.option_context("io.sql.engine", "auto"):
            assert isinstance(get_engine("auto"), SQLAlchemyEngine)
            assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine)

    def test_get_engine_auto_error_message(self):
        # Expect different error messages from get_engine(engine="auto")
        # if engines aren't installed vs. are installed but bad version
        pass
        # TODO(GH#36893) fill this in when we add more engines

    @pytest.mark.parametrize("func", ["read_sql", "read_sql_query"])
    def test_read_sql_dtype_backend(self, string_storage, func, dtype_backend):
        # GH#50048
        table = "test"
        df = self.dtype_backend_data()
        df.to_sql(table, self.conn, index=False, if_exists="replace")

        with pd.option_context("mode.string_storage", string_storage):
            result = getattr(pd, func)(
                f"Select * from {table}", self.conn, dtype_backend=dtype_backend
            )
        expected = self.dtype_backend_expected(string_storage, dtype_backend)
        tm.assert_frame_equal(result, expected)

        with pd.option_context("mode.string_storage", string_storage):
            iterator = getattr(pd, func)(
                f"Select * from {table}",
                self.conn,
                dtype_backend=dtype_backend,
                chunksize=3,
            )
            expected = self.dtype_backend_expected(string_storage, dtype_backend)
            for result in iterator:
                tm.assert_frame_equal(result, expected)

    @pytest.mark.parametrize("func", ["read_sql", "read_sql_table"])
    def test_read_sql_dtype_backend_table(self, string_storage, func, dtype_backend):
        # GH#50048
        table = "test"
        df = self.dtype_backend_data()
        df.to_sql(table, self.conn, index=False, if_exists="replace")

        with pd.option_context("mode.string_storage", string_storage):
            result = getattr(pd, func)(table, self.conn, dtype_backend=dtype_backend)
        expected = self.dtype_backend_expected(string_storage, dtype_backend)
        tm.assert_frame_equal(result, expected)

        with pd.option_context("mode.string_storage", string_storage):
            iterator = getattr(pd, func)(
                table,
                self.conn,
                dtype_backend=dtype_backend,
                chunksize=3,
            )
            expected = self.dtype_backend_expected(string_storage, dtype_backend)
            for result in iterator:
                tm.assert_frame_equal(result, expected)

    @pytest.mark.parametrize("func", ["read_sql", "read_sql_table", "read_sql_query"])
    def test_read_sql_invalid_dtype_backend_table(self, func):
        table = "test"
        df = self.dtype_backend_data()
        df.to_sql(table, self.conn, index=False, if_exists="replace")

        msg = (
            "dtype_backend numpy is invalid, only 'numpy_nullable' and "
            "'pyarrow' are allowed."
        )
        with pytest.raises(ValueError, match=msg):
            getattr(pd, func)(table, self.conn, dtype_backend="numpy")

    def dtype_backend_data(self) -> DataFrame:
        return DataFrame(
            {
                "a": Series([1, np.nan, 3], dtype="Int64"),
                "b": Series([1, 2, 3], dtype="Int64"),
                "c": Series([1.5, np.nan, 2.5], dtype="Float64"),
                "d": Series([1.5, 2.0, 2.5], dtype="Float64"),
                "e": [True, False, None],
                "f": [True, False, True],
                "g": ["a", "b", "c"],
                "h": ["a", "b", None],
            }
        )

    def dtype_backend_expected(self, storage, dtype_backend) -> DataFrame:
        string_array: StringArray | ArrowStringArray
        string_array_na: StringArray | ArrowStringArray
        if storage == "python":
            string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_))
            string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_))
        else:
            pa = pytest.importorskip("pyarrow")
            string_array = ArrowStringArray(pa.array(["a", "b", "c"]))
            string_array_na = ArrowStringArray(pa.array(["a", "b", None]))

        df = DataFrame(
            {
                "a": Series([1, np.nan, 3], dtype="Int64"),
                "b": Series([1, 2, 3], dtype="Int64"),
                "c": Series([1.5, np.nan, 2.5], dtype="Float64"),
                "d": Series([1.5, 2.0, 2.5], dtype="Float64"),
                "e": Series([True, False, pd.NA], dtype="boolean"),
                "f": Series([True, False, True], dtype="boolean"),
                "g": string_array,
                "h": string_array_na,
            }
        )
        if dtype_backend == "pyarrow":
            pa = pytest.importorskip("pyarrow")

            from pandas.arrays import ArrowExtensionArray

            df = DataFrame(
                {
                    col: ArrowExtensionArray(pa.array(df[col], from_pandas=True))
                    for col in df.columns
                }
            )
        return df

    def test_chunksize_empty_dtypes(self):
        # GH#50245
        dtypes = {"a": "int64", "b": "object"}
        df = DataFrame(columns=["a", "b"]).astype(dtypes)
        expected = df.copy()
        df.to_sql("test", self.conn, index=False, if_exists="replace")

        for result in read_sql_query(
            "SELECT * FROM test",
            self.conn,
            dtype=dtypes,
            chunksize=1,
        ):
            tm.assert_frame_equal(result, expected)

    @pytest.mark.parametrize("dtype_backend", [lib.no_default, "numpy_nullable"])
    @pytest.mark.parametrize("func", ["read_sql", "read_sql_query"])
    def test_read_sql_dtype(self, func, dtype_backend):
        # GH#50797
        table = "test"
        df = DataFrame({"a": [1, 2, 3], "b": 5})
        df.to_sql(table, self.conn, index=False, if_exists="replace")

        result = getattr(pd, func)(
            f"Select * from {table}",
            self.conn,
            dtype={"a": np.float64},
            dtype_backend=dtype_backend,
        )
        expected = DataFrame(
            {
                "a": Series([1, 2, 3], dtype=np.float64),
                "b": Series(
                    [5, 5, 5],
                    dtype="int64" if not dtype_backend == "numpy_nullable" else "Int64",
                ),
            }
        )
        tm.assert_frame_equal(result, expected)


class TestSQLiteAlchemy(_TestSQLAlchemy):
    """
    Test the sqlalchemy backend against an in-memory sqlite database.
    """

    flavor = "sqlite"

    @classmethod
    def setup_engine(cls):
        cls.engine = sqlalchemy.create_engine("sqlite:///:memory:")

    @classmethod
    def setup_driver(cls):
        # sqlite3 is built-in
        cls.driver = None

    def test_default_type_conversion(self):
        df = sql.read_sql_table("types", self.conn)

        assert issubclass(df.FloatCol.dtype.type, np.floating)
        assert issubclass(df.IntCol.dtype.type, np.integer)

        # sqlite has no boolean type, so integer type is returned
        assert issubclass(df.BoolCol.dtype.type, np.integer)

        # Int column with NA values stays as float
        assert issubclass(df.IntColWithNull.dtype.type, np.floating)

        # Non-native Bool column with NA values stays as float
        assert issubclass(df.BoolColWithNull.dtype.type, np.floating)

    def test_default_date_load(self):
        df = sql.read_sql_table("types", self.conn)

        # IMPORTANT - sqlite has no native date type, so this column should
        # not be parsed as datetime
        assert not issubclass(df.DateCol.dtype.type, np.datetime64)

    def test_bigint_warning(self):
        # test no warning for BIGINT (to support int64) is raised (GH7433)
        df = DataFrame({"a": [1, 2]}, dtype="int64")
        assert df.to_sql("test_bigintwarning", self.conn, index=False) == 2

        with tm.assert_produces_warning(None):
            sql.read_sql_table("test_bigintwarning", self.conn)

    def test_row_object_is_named_tuple(self):
        # GH 40682
        # Test for the is_named_tuple() function
        # Placed here due to its usage of sqlalchemy
        from sqlalchemy import (
            Column,
            Integer,
            String,
        )
        from sqlalchemy.orm import (
            declarative_base,
            sessionmaker,
        )

        BaseModel = declarative_base()

        class Test(BaseModel):
            __tablename__ = "test_frame"
            id = Column(Integer, primary_key=True)
            string_column = Column(String(50))

        with self.conn.begin():
            BaseModel.metadata.create_all(self.conn)
        Session = sessionmaker(bind=self.conn)
        with Session() as session:
            df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]})
            assert (
                df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace")
                == 2
            )
            session.commit()
            test_query = session.query(Test.id, Test.string_column)
            df = DataFrame(test_query)

        assert list(df.columns) == ["id", "string_column"]

    def dtype_backend_expected(self, storage, dtype_backend) -> DataFrame:
        df = super().dtype_backend_expected(storage, dtype_backend)
        if dtype_backend == "numpy_nullable":
            df = df.astype({"e": "Int64", "f": "Int64"})
        else:
            df = df.astype({"e": "int64[pyarrow]", "f": "int64[pyarrow]"})

        return df

    @pytest.mark.parametrize("func", ["read_sql", "read_sql_table"])
    def test_read_sql_dtype_backend_table(self, string_storage, func):
        # GH#50048 Not supported for sqlite
        pass


@pytest.mark.db
class TestMySQLAlchemy(_TestSQLAlchemy):
    """
    Test the sqlalchemy backend against a MySQL database.
    """

    flavor = "mysql"
    port = 3306

    @classmethod
    def setup_engine(cls):
        cls.engine = sqlalchemy.create_engine(
            f"mysql+{cls.driver}://root@localhost:{cls.port}/pandas",
            connect_args=cls.connect_args,
        )

    @classmethod
    def setup_driver(cls):
        pymysql = pytest.importorskip("pymysql")
        cls.driver = "pymysql"
        cls.connect_args = {"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS}
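        # MULTI_STATEMENTS allows several ;-separated statements in a single
        # execute() call (presumably relied on when loading the test fixtures)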

    def test_default_type_conversion(self):
        pass

    def dtype_backend_expected(self, storage, dtype_backend) -> DataFrame:
        df = super().dtype_backend_expected(storage, dtype_backend)
        if dtype_backend == "numpy_nullable":
            df = df.astype({"e": "Int64", "f": "Int64"})
        else:
            df = df.astype({"e": "int64[pyarrow]", "f": "int64[pyarrow]"})

        return df


@pytest.mark.db
class TestPostgreSQLAlchemy(_TestSQLAlchemy):
    """
    Test the sqlalchemy backend against a PostgreSQL database.
    """

    flavor = "postgresql"
    port = 5432

    @classmethod
    def setup_engine(cls):
        cls.engine = sqlalchemy.create_engine(
            f"postgresql+{cls.driver}://postgres:postgres@localhost:{cls.port}/pandas"
        )

    @classmethod
    def setup_driver(cls):
        pytest.importorskip("psycopg2")
        cls.driver = "psycopg2"

    def test_schema_support(self):
        from sqlalchemy.engine import Engine

        # only test this for postgresql (schemas not supported in
        # mysql/sqlite)
        df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})

        # create a schema
        with self.conn.begin():
            self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
            self.conn.exec_driver_sql("CREATE SCHEMA other;")

        # write dataframe to different schemas
        assert df.to_sql("test_schema_public", self.conn, index=False) == 2
        assert (
            df.to_sql(
                "test_schema_public_explicit", self.conn, index=False, schema="public"
            )
            == 2
        )
        assert (
            df.to_sql("test_schema_other", self.conn, index=False, schema="other") == 2
        )

        # read dataframes back in
        res1 = sql.read_sql_table("test_schema_public", self.conn)
        tm.assert_frame_equal(df, res1)
        res2 = sql.read_sql_table("test_schema_public_explicit", self.conn)
        tm.assert_frame_equal(df, res2)
        res3 = sql.read_sql_table(
            "test_schema_public_explicit", self.conn, schema="public"
        )
        tm.assert_frame_equal(df, res3)
        res4 = sql.read_sql_table("test_schema_other", self.conn, schema="other")
        tm.assert_frame_equal(df, res4)
        msg = "Table test_schema_other not found"
        with pytest.raises(ValueError, match=msg):
            sql.read_sql_table("test_schema_other", self.conn, schema="public")

        # different if_exists options

        # create a schema
        with self.conn.begin():
            self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
            self.conn.exec_driver_sql("CREATE SCHEMA other;")

        # write dataframe with different if_exists options
        assert (
            df.to_sql("test_schema_other", self.conn, schema="other", index=False) == 2
        )
        df.to_sql(
            "test_schema_other",
            self.conn,
            schema="other",
            index=False,
            if_exists="replace",
        )
        assert (
            df.to_sql(
                "test_schema_other",
                self.conn,
                schema="other",
                index=False,
                if_exists="append",
            )
            == 2
        )
        res = sql.read_sql_table("test_schema_other", self.conn, schema="other")
        tm.assert_frame_equal(concat([df, df], ignore_index=True), res)

        # specifying schema in user-provided meta

        # The schema won't be applied on another Connection
        # because of transactional schemas
        if isinstance(self.conn, Engine):
            engine2 = self.connect()
            pdsql = sql.SQLDatabase(engine2, schema="other")
            assert pdsql.to_sql(df, "test_schema_other2", index=False) == 2
            assert (
                pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="replace")
                == 2
            )
            assert (
                pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="append")
                == 2
            )
            res1 = sql.read_sql_table("test_schema_other2", self.conn, schema="other")
            res2 = pdsql.read_table("test_schema_other2")
            tm.assert_frame_equal(res1, res2)


# -----------------------------------------------------------------------------
# -- Test Sqlite / MySQL fallback


class TestSQLiteFallback(SQLiteMixIn, PandasSQLTest):
    """
    Test the fallback mode against an in-memory sqlite database.
    """

    flavor = "sqlite"

    @pytest.fixture(autouse=True)
    def setup_method(self, iris_path, types_data):
        self.conn = self.connect()
        self.load_iris_data(iris_path)
        self.load_types_data(types_data)
        self.pandasSQL = sql.SQLiteDatabase(self.conn)

    def test_read_sql_parameter(self):
        self._read_sql_iris_parameter()

    def test_read_sql_named_parameter(self):
        self._read_sql_iris_named_parameter()

    def test_to_sql_empty(self, test_frame1):
        self._to_sql_empty(test_frame1)

    def test_create_and_drop_table(self):
        temp_frame = DataFrame(
            {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
        )
        assert self.pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4
        assert self.pandasSQL.has_table("drop_test_frame")

        self.pandasSQL.drop_table("drop_test_frame")
        assert not self.pandasSQL.has_table("drop_test_frame")

    def test_roundtrip(self, test_frame1):
        self._roundtrip(test_frame1)

    def test_execute_sql(self):
        self._execute_sql()

    def test_datetime_date(self):
        # test support for datetime.date
        df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"])
        assert df.to_sql("test_date", self.conn, index=False) == 2
        res = read_sql_query("SELECT * FROM test_date", self.conn)
        if self.flavor == "sqlite":
            # comes back as strings
            tm.assert_frame_equal(res, df.astype(str))
        elif self.flavor == "mysql":
            tm.assert_frame_equal(res, df)

    @pytest.mark.parametrize("tz_aware", [False, True])
    def test_datetime_time(self, tz_aware):
        # test support for datetime.time, GH #8341
        if not tz_aware:
            tz_times = [time(9, 0, 0), time(9, 1, 30)]
        else:
            tz_dt = date_range("2013-01-01 09:00:00", periods=2, tz="US/Pacific")
            tz_times = Series(tz_dt.to_pydatetime()).map(lambda dt: dt.timetz())
  2432. df = DataFrame(tz_times, columns=["a"])
  2433. assert df.to_sql("test_time", self.conn, index=False) == 2
  2434. res = read_sql_query("SELECT * FROM test_time", self.conn)
  2435. if self.flavor == "sqlite":
  2436. # comes back as strings
  2437. expected = df.applymap(lambda _: _.strftime("%H:%M:%S.%f"))
  2438. tm.assert_frame_equal(res, expected)
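
    # Helper: collect the column lists of every index on ``tbl_name``, read
    # from the sqlite_master catalog plus ``PRAGMA index_info``.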
    def _get_index_columns(self, tbl_name):
        ixs = sql.read_sql_query(
            "SELECT * FROM sqlite_master WHERE type = 'index' "
            + f"AND tbl_name = '{tbl_name}'",
            self.conn,
        )
        ix_cols = []
        for ix_name in ixs.name:
            ix_info = sql.read_sql_query(f"PRAGMA index_info({ix_name})", self.conn)
            ix_cols.append(ix_info.name.tolist())
        return ix_cols

    def test_to_sql_save_index(self):
        self._to_sql_save_index()

    def test_transactions(self):
        self._transaction_test()
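
    # Helper: look up the declared SQLite type of ``column`` in ``table``
    # via ``PRAGMA table_info``.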
    def _get_sqlite_column_type(self, table, column):
        recs = self.conn.execute(f"PRAGMA table_info({table})")
        for cid, name, ctype, not_null, default, pk in recs:
            if name == column:
                return ctype
        raise ValueError(f"Table {table}, column {column} not found")

    def test_dtype(self):
        if self.flavor == "mysql":
            pytest.skip("Not applicable to MySQL legacy")
        cols = ["A", "B"]
        data = [(0.8, True), (0.9, None)]
        df = DataFrame(data, columns=cols)
        assert df.to_sql("dtype_test", self.conn) == 2
        assert df.to_sql("dtype_test2", self.conn, dtype={"B": "STRING"}) == 2

        # sqlite stores Boolean values as INTEGER
        assert self._get_sqlite_column_type("dtype_test", "B") == "INTEGER"
        assert self._get_sqlite_column_type("dtype_test2", "B") == "STRING"

        msg = r"B \(<class 'bool'>\) not a string"
        with pytest.raises(ValueError, match=msg):
            df.to_sql("error", self.conn, dtype={"B": bool})

        # single dtype
        assert df.to_sql("single_dtype_test", self.conn, dtype="STRING") == 2
        assert self._get_sqlite_column_type("single_dtype_test", "A") == "STRING"
        assert self._get_sqlite_column_type("single_dtype_test", "B") == "STRING"

    def test_notna_dtype(self):
        if self.flavor == "mysql":
            pytest.skip("Not applicable to MySQL legacy")
        cols = {
            "Bool": Series([True, None]),
            "Date": Series([datetime(2012, 5, 1), None]),
            "Int": Series([1, None], dtype="object"),
            "Float": Series([1.1, None]),
        }
        df = DataFrame(cols)

        tbl = "notna_dtype_test"
        assert df.to_sql(tbl, self.conn) == 2

        assert self._get_sqlite_column_type(tbl, "Bool") == "INTEGER"
        assert self._get_sqlite_column_type(tbl, "Date") == "TIMESTAMP"
        assert self._get_sqlite_column_type(tbl, "Int") == "INTEGER"
        assert self._get_sqlite_column_type(tbl, "Float") == "REAL"

    def test_illegal_names(self):
        # For sqlite, these should work fine
        df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"])

        msg = "Empty table or column name specified"
        with pytest.raises(ValueError, match=msg):
            df.to_sql("", self.conn)

        for ndx, weird_name in enumerate(
            [
                "test_weird_name]",
                "test_weird_name[",
                "test_weird_name`",
                'test_weird_name"',
                "test_weird_name'",
                "_b.test_weird_name_01-30",
                '"_b.test_weird_name_01-30"',
                "99beginswithnumber",
                "12345",
                "\xe9",
            ]
        ):
            assert df.to_sql(weird_name, self.conn) == 2
            sql.table_exists(weird_name, self.conn)

            df2 = DataFrame([[1, 2], [3, 4]], columns=["a", weird_name])
            c_tbl = f"test_weird_col_name{ndx:d}"
            assert df2.to_sql(c_tbl, self.conn) == 2
            sql.table_exists(c_tbl, self.conn)


# -----------------------------------------------------------------------------
# -- Old tests from 0.13.1 (before refactor using sqlalchemy)
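

# Maps a Python value type to a callable that renders a value of that type as
# a SQL literal; ``format_query`` below uses it to interpolate parameters.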
_formatters = {
    datetime: "'{}'".format,
    str: "'{}'".format,
    np.str_: "'{}'".format,
    bytes: "'{}'".format,
    float: "{:.8f}".format,
    int: "{:d}".format,
    type(None): lambda x: "NULL",
    np.float64: "{:.10f}".format,
    bool: "'{!s}'".format,
}


def format_query(sql, *args):
    processed_args = []
    for arg in args:
        if isinstance(arg, float) and isna(arg):
            arg = None

        formatter = _formatters[type(arg)]
        processed_args.append(formatter(arg))

    return sql % tuple(processed_args)
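
# Illustrative example (not itself a test): NaN floats are rendered as NULL,
# other values via ``_formatters``, e.g.
#   format_query("INSERT INTO t VALUES (%s, %s)", 1.5, float("nan"))
#   -> "INSERT INTO t VALUES (1.50000000, NULL)"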


def tquery(query, con=None):
    """Replace removed sql.tquery function"""
    with sql.pandasSQL_builder(con) as pandas_sql:
        res = pandas_sql.execute(query).fetchall()
    return None if res is None else list(res)
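
# Illustrative usage (table name is just an example): tquery returns the
# fetched rows as a list of tuples, e.g.
#   tquery("SELECT * FROM table_if_exists", con=sqlite_buildin)
#   -> [(1, "A"), (2, "B")]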


class TestXSQLite:
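    """
    Legacy tests (pre-SQLAlchemy refactor) run directly against a built-in
    ``sqlite3`` connection.
    """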

    def drop_table(self, table_name, conn):
        cur = conn.cursor()
        cur.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}")
        conn.commit()

    def test_basic(self, sqlite_buildin):
        frame = tm.makeTimeDataFrame()
        assert (
            sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 30
        )
        result = sql.read_sql("select * from test_table", sqlite_buildin)

        # HACK! Change this once indexes are handled properly.
        result.index = frame.index

        expected = frame
        tm.assert_frame_equal(result, expected)

        frame["txt"] = ["a"] * len(frame)
        frame2 = frame.copy()
        new_idx = Index(np.arange(len(frame2)), dtype=np.int64) + 10
        frame2["Idx"] = new_idx.copy()
        assert (
            sql.to_sql(frame2, name="test_table2", con=sqlite_buildin, index=False)
            == 30
        )
        result = sql.read_sql(
            "select * from test_table2", sqlite_buildin, index_col="Idx"
        )
        expected = frame.copy()
        expected.index = new_idx
        expected.index.name = "Idx"
        tm.assert_frame_equal(expected, result)

    def test_write_row_by_row(self, sqlite_buildin):
        frame = tm.makeTimeDataFrame()
        frame.iloc[0, 0] = np.nan
        create_sql = sql.get_schema(frame, "test")
        cur = sqlite_buildin.cursor()
        cur.execute(create_sql)

        ins = "INSERT INTO test VALUES (%s, %s, %s, %s)"
        for _, row in frame.iterrows():
            fmt_sql = format_query(ins, *row)
            tquery(fmt_sql, con=sqlite_buildin)

        sqlite_buildin.commit()

        result = sql.read_sql("select * from test", con=sqlite_buildin)
        result.index = frame.index
        tm.assert_frame_equal(result, frame, rtol=1e-3)

    def test_execute(self, sqlite_buildin):
        frame = tm.makeTimeDataFrame()
        create_sql = sql.get_schema(frame, "test")
        cur = sqlite_buildin.cursor()
        cur.execute(create_sql)
        ins = "INSERT INTO test VALUES (?, ?, ?, ?)"

        row = frame.iloc[0]
        with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql:
            pandas_sql.execute(ins, tuple(row))

        sqlite_buildin.commit()
        result = sql.read_sql("select * from test", sqlite_buildin)
        result.index = frame.index[:1]
        tm.assert_frame_equal(result, frame[:1])

    def test_schema(self, sqlite_buildin):
        frame = tm.makeTimeDataFrame()
        create_sql = sql.get_schema(frame, "test")
        lines = create_sql.splitlines()
        for line in lines:
            tokens = line.split(" ")
            if len(tokens) == 2 and tokens[0] == "A":
                assert tokens[1] == "DATETIME"

        create_sql = sql.get_schema(frame, "test", keys=["A", "B"])
        assert 'PRIMARY KEY ("A", "B")' in create_sql
        cur = sqlite_buildin.cursor()
        cur.execute(create_sql)

    def test_execute_fail(self, sqlite_buildin):
        create_sql = """
        CREATE TABLE test
        (
        a TEXT,
        b TEXT,
        c REAL,
        PRIMARY KEY (a, b)
        );
        """
        cur = sqlite_buildin.cursor()
        cur.execute(create_sql)

        with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql:
            pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)')
            pandas_sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)')

            with pytest.raises(sql.DatabaseError, match="Execution failed on sql"):
                pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 7)')

    def test_execute_closed_connection(self):
        create_sql = """
        CREATE TABLE test
        (
        a TEXT,
        b TEXT,
        c REAL,
        PRIMARY KEY (a, b)
        );
        """
        with contextlib.closing(sqlite3.connect(":memory:")) as conn:
            cur = conn.cursor()
            cur.execute(create_sql)

            with sql.pandasSQL_builder(conn) as pandas_sql:
                pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)')

        msg = "Cannot operate on a closed database."
        with pytest.raises(sqlite3.ProgrammingError, match=msg):
            tquery("select * from test", con=conn)

    def test_keyword_as_column_names(self, sqlite_buildin):
        df = DataFrame({"From": np.ones(5)})
        assert sql.to_sql(df, con=sqlite_buildin, name="testkeywords", index=False) == 5

    def test_onecolumn_of_integer(self, sqlite_buildin):
        # GH 3628
        # a column_of_integers dataframe should transfer well to sql
        mono_df = DataFrame([1, 2], columns=["c0"])
        assert sql.to_sql(mono_df, con=sqlite_buildin, name="mono_df", index=False) == 2
        # computing the sum via sql
        con_x = sqlite_buildin
        the_sum = sum(my_c0[0] for my_c0 in con_x.execute("select * from mono_df"))
        # it should not fail, and gives 3 (GH#3628)
        assert the_sum == 3

        result = sql.read_sql("select * from mono_df", con_x)
        tm.assert_frame_equal(result, mono_df)

    def test_if_exists(self, sqlite_buildin):
        df_if_exists_1 = DataFrame({"col1": [1, 2], "col2": ["A", "B"]})
        df_if_exists_2 = DataFrame({"col1": [3, 4, 5], "col2": ["C", "D", "E"]})
        table_name = "table_if_exists"
        sql_select = f"SELECT * FROM {table_name}"

        msg = "'notvalidvalue' is not valid for if_exists"
        with pytest.raises(ValueError, match=msg):
            sql.to_sql(
                frame=df_if_exists_1,
                con=sqlite_buildin,
                name=table_name,
                if_exists="notvalidvalue",
            )
        self.drop_table(table_name, sqlite_buildin)

        # test if_exists='fail'
        sql.to_sql(
            frame=df_if_exists_1, con=sqlite_buildin, name=table_name, if_exists="fail"
        )
        msg = "Table 'table_if_exists' already exists"
        with pytest.raises(ValueError, match=msg):
            sql.to_sql(
                frame=df_if_exists_1,
                con=sqlite_buildin,
                name=table_name,
                if_exists="fail",
            )

        # test if_exists='replace'
        sql.to_sql(
            frame=df_if_exists_1,
            con=sqlite_buildin,
            name=table_name,
            if_exists="replace",
            index=False,
        )
        assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")]
        assert (
            sql.to_sql(
                frame=df_if_exists_2,
                con=sqlite_buildin,
                name=table_name,
                if_exists="replace",
                index=False,
            )
            == 3
        )
        assert tquery(sql_select, con=sqlite_buildin) == [(3, "C"), (4, "D"), (5, "E")]
        self.drop_table(table_name, sqlite_buildin)

        # test if_exists='append'
        assert (
            sql.to_sql(
                frame=df_if_exists_1,
                con=sqlite_buildin,
                name=table_name,
                if_exists="fail",
                index=False,
            )
            == 2
        )
        assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")]
        assert (
            sql.to_sql(
                frame=df_if_exists_2,
                con=sqlite_buildin,
                name=table_name,
                if_exists="append",
                index=False,
            )
            == 3
        )
        assert tquery(sql_select, con=sqlite_buildin) == [
            (1, "A"),
            (2, "B"),
            (3, "C"),
            (4, "D"),
            (5, "E"),
        ]
        self.drop_table(table_name, sqlite_buildin)