stata.py 130 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
7327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562356335643565356635673568356935703571357235733574357535763577357835793580358135823583358435853586358735883589359035913592359335943595359635973598359936003601360236033604360536063607360836093610361136123613361436153616361736183619362036213622362336243625362636273628362936303631363236333634363536363637363836393640364136423643364436453646364736483649365036513652365336543655365636573658365936603661366236633664366536663667366836693670367136723673367436753676367736783679368036813682368336843685368636873688368936903691369236933694369536963697369836993700370137023703370437053706370737083709371037113712371337143715371637173718371937203721
"""
Module contains tools for processing Stata files into DataFrames.

The StataReader below was originally written by Joe Presbrey as part of PyDTA.
It has been extended and improved by Skipper Seabold from the Statsmodels
project, who also developed the StataWriter; it was finally added to pandas in
a further improved version.

You can find more information on http://presbrey.mit.edu/PyDTA and
https://www.statsmodels.org/devel/
"""
  10. from __future__ import annotations
  11. from collections import abc
  12. import datetime
  13. from io import BytesIO
  14. import os
  15. import struct
  16. import sys
  17. from types import TracebackType
  18. from typing import (
  19. IO,
  20. TYPE_CHECKING,
  21. Any,
  22. AnyStr,
  23. Callable,
  24. Final,
  25. Hashable,
  26. Sequence,
  27. cast,
  28. )
  29. import warnings
  30. from dateutil.relativedelta import relativedelta
  31. import numpy as np
  32. from pandas._libs.lib import infer_dtype
  33. from pandas._libs.writers import max_len_string_array
  34. from pandas._typing import (
  35. CompressionOptions,
  36. FilePath,
  37. ReadBuffer,
  38. StorageOptions,
  39. WriteBuffer,
  40. )
  41. from pandas.errors import (
  42. CategoricalConversionWarning,
  43. InvalidColumnName,
  44. PossiblePrecisionLoss,
  45. ValueLabelTypeMismatch,
  46. )
  47. from pandas.util._decorators import (
  48. Appender,
  49. doc,
  50. )
  51. from pandas.util._exceptions import find_stack_level
  52. from pandas.core.dtypes.common import (
  53. ensure_object,
  54. is_categorical_dtype,
  55. is_datetime64_dtype,
  56. is_numeric_dtype,
  57. )
  58. from pandas import (
  59. Categorical,
  60. DatetimeIndex,
  61. NaT,
  62. Timestamp,
  63. isna,
  64. to_datetime,
  65. to_timedelta,
  66. )
  67. from pandas.core.arrays.boolean import BooleanDtype
  68. from pandas.core.arrays.integer import IntegerDtype
  69. from pandas.core.frame import DataFrame
  70. from pandas.core.indexes.base import Index
  71. from pandas.core.series import Series
  72. from pandas.core.shared_docs import _shared_docs
  73. from pandas.io.common import get_handle
  74. if TYPE_CHECKING:
  75. from typing import Literal
  76. _version_error = (
  77. "Version of given Stata file is {version}. pandas supports importing "
  78. "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
  79. "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16),"
  80. "and 119 (Stata 15/16, over 32,767 variables)."
  81. )
# Shared numpydoc fragments describing the reader's processing options.
# They are interpolated into the f-string docstring templates that follow, so
# their exact text becomes part of those docstrings.
_statafile_processing_params1 = """\
convert_dates : bool, default True
Convert date variables to DataFrame time values.
convert_categoricals : bool, default True
Read value labels and convert columns to Categorical/Factor variables."""
# Options controlling index selection, missing-value handling, dtypes,
# column selection and categorical ordering.
_statafile_processing_params2 = """\
index_col : str, optional
Column to set as index.
convert_missing : bool, default False
Flag indicating whether to convert missing values to their Stata
representations. If False, missing values are replaced with nan.
If True, columns containing missing values are returned with
object data types and missing values are represented by
StataMissingValue objects.
preserve_dtypes : bool, default True
Preserve Stata datatypes. If False, numeric data are upcast to pandas
default types for foreign data (float64 or int64).
columns : list or None
Columns to retain. Columns will be returned in the given order. None
returns all columns.
order_categoricals : bool, default True
Flag indicating whether converted categorical data are ordered."""
# Documents the ``chunksize`` argument (reading in fixed-size chunks).
_chunksize_params = """\
chunksize : int, default None
Return StataReader object for iterations, returns chunks with
given number of lines."""
# Documents the ``iterator`` argument.
_iterator_params = """\
iterator : bool, default False
Return StataReader object."""
# Notes section shared by the reader docstrings: chunked reads can produce
# categoricals whose categories differ between chunks.
_reader_notes = """\
Notes
-----
Categorical variables read through an iterator may not have the same
categories and dtype. This occurs when a variable stored in a DTA
file is associated to an incomplete set of value labels that only
label a strict subset of the values."""
  118. _read_stata_doc = f"""
  119. Read Stata file into DataFrame.
  120. Parameters
  121. ----------
  122. filepath_or_buffer : str, path object or file-like object
  123. Any valid string path is acceptable. The string could be a URL. Valid
  124. URL schemes include http, ftp, s3, and file. For file URLs, a host is
  125. expected. A local file could be: ``file://localhost/path/to/table.dta``.
  126. If you want to pass in a path object, pandas accepts any ``os.PathLike``.
  127. By file-like object, we refer to objects with a ``read()`` method,
  128. such as a file handle (e.g. via builtin ``open`` function)
  129. or ``StringIO``.
  130. {_statafile_processing_params1}
  131. {_statafile_processing_params2}
  132. {_chunksize_params}
  133. {_iterator_params}
  134. {_shared_docs["decompression_options"] % "filepath_or_buffer"}
  135. {_shared_docs["storage_options"]}
  136. Returns
  137. -------
  138. DataFrame or StataReader
  139. See Also
  140. --------
  141. io.stata.StataReader : Low-level reader for Stata data files.
  142. DataFrame.to_stata: Export Stata data files.
  143. {_reader_notes}
  144. Examples
  145. --------
  146. Creating a dummy stata for this example
  147. >>> df = pd.DataFrame({{'animal': ['falcon', 'parrot', 'falcon', 'parrot'],
  148. ... 'speed': [350, 18, 361, 15]}}) # doctest: +SKIP
  149. >>> df.to_stata('animals.dta') # doctest: +SKIP
  150. Read a Stata dta file:
  151. >>> df = pd.read_stata('animals.dta') # doctest: +SKIP
  152. Read a Stata dta file in 10,000 line chunks:
  153. >>> values = np.random.randint(0, 10, size=(20_000, 1), dtype="uint8") # doctest: +SKIP
  154. >>> df = pd.DataFrame(values, columns=["i"]) # doctest: +SKIP
  155. >>> df.to_stata('filename.dta') # doctest: +SKIP
  156. >>> with pd.read_stata('filename.dta', chunksize=10000) as itr: # doctest: +SKIP
  157. >>> for chunk in itr:
  158. ... # Operate on a single chunk, e.g., chunk.mean()
  159. ... pass # doctest: +SKIP
  160. """
# Docstring template for a method that reads observations into a DataFrame
# (presumably the reader's ``read`` method, given the ``nrows`` parameter it
# documents — confirm at the use site). The processing-option fragments above
# are interpolated at import time.
_read_method_doc = f"""\
Reads observations from Stata file, converting them into a dataframe
Parameters
----------
nrows : int
Number of lines to read from data file, if None read whole file.
{_statafile_processing_params1}
{_statafile_processing_params2}
Returns
-------
DataFrame
"""
  173. _stata_reader_doc = f"""\
  174. Class for reading Stata dta files.
  175. Parameters
  176. ----------
  177. path_or_buf : path (string), buffer or path object
  178. string, path object (pathlib.Path or py._path.local.LocalPath) or object
  179. implementing a binary read() functions.
  180. {_statafile_processing_params1}
  181. {_statafile_processing_params2}
  182. {_chunksize_params}
  183. {_shared_docs["decompression_options"]}
  184. {_shared_docs["storage_options"]}
  185. {_reader_notes}
  186. """
  187. _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]
  188. stata_epoch: Final = datetime.datetime(1960, 1, 1)
# TODO: Add typing. As of January 2020 it is not possible to type this function since
# mypy doesn't understand that a Series and an int can be combined using mathematical
# operations. (+, -).
def _stata_elapsed_date_to_datetime_vec(dates, fmt) -> Series:
    """
    Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        The Stata Internal Format date to convert to datetime according to fmt.
        NaN entries are treated as missing. NOTE: when NaNs are present they
        are overwritten in place on this very Series before conversion.
    fmt : str
        The Stata format to convert from, matched by prefix with or without a
        leading ``%``: tc, tC, td (also d), tw, tm, tq, th, ty.

    Returns
    -------
    converted : Series
        The converted dates; missing inputs become NaT. For ``%tC`` a warning
        is emitted and the integer SIF values are returned as an object Series.

    Raises
    ------
    ValueError
        If ``fmt`` does not start with one of the formats listed above.

    Examples
    --------
    >>> dates = pd.Series([52])
    >>> _stata_elapsed_date_to_datetime_vec(dates, "%tw")
    0   1961-01-01
    dtype: datetime64[ns]

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1
    yearly date - ty
        years since 0000
    """
    # Bounds of what Timestamp can represent; inputs outside them are routed to
    # the slower datetime/relativedelta fallbacks in the helpers below.
    MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year
    MAX_DAY_DELTA = (Timestamp.max - datetime.datetime(1960, 1, 1)).days
    MIN_DAY_DELTA = (Timestamp.min - datetime.datetime(1960, 1, 1)).days
    MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000
    MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000

    def convert_year_month_safe(year, month) -> Series:
        """
        Convert year and month to datetimes, using pandas vectorized versions
        when the date range falls within the range supported by pandas.
        Otherwise it falls back to a slower but more robust method
        using datetime.
        """
        if year.max() < MAX_YEAR and year.min() > MIN_YEAR:
            return to_datetime(100 * year + month, format="%Y%m")
        else:
            index = getattr(year, "index", None)
            return Series(
                [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index
            )

    def convert_year_days_safe(year, days) -> Series:
        """
        Converts year (e.g. 1999) and days since the start of the year to a
        datetime or datetime64 Series
        """
        # MAX_YEAR - 1: adding up to a year of days must still stay in range.
        if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR:
            return to_datetime(year, format="%Y") + to_timedelta(days, unit="d")
        else:
            index = getattr(year, "index", None)
            value = [
                datetime.datetime(y, 1, 1) + relativedelta(days=int(d))
                for y, d in zip(year, days)
            ]
            return Series(value, index=index)

    def convert_delta_safe(base, deltas, unit) -> Series:
        """
        Convert base dates and deltas to datetimes, using pandas vectorized
        versions if the deltas satisfy restrictions required to be expressed
        as dates in pandas.

        ``unit`` must be "d" (days) or "ms" (milliseconds); anything else
        raises ValueError.
        """
        index = getattr(deltas, "index", None)
        if unit == "d":
            if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA:
                values = [base + relativedelta(days=int(d)) for d in deltas]
                return Series(values, index=index)
        elif unit == "ms":
            if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA:
                values = [
                    base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas
                ]
                return Series(values, index=index)
        else:
            raise ValueError("format not understood")
        # In range: fall through to the vectorized pandas path.
        base = to_datetime(base)
        deltas = to_timedelta(deltas, unit=unit)
        return base + deltas

    # TODO(non-nano): If/when pandas supports more than datetime64[ns], this
    # should be improved to use correct range, e.g. datetime[Y] for yearly
    bad_locs = np.isnan(dates)
    has_bad_values = False
    if bad_locs.any():
        has_bad_values = True
        # reset cache to avoid SettingWithCopy checks (we own the DataFrame and the
        # `dates` Series is used to overwrite itself in the DataFrame)
        dates._reset_cacher()
        # NaN cannot be cast to int64: write a placeholder (into the caller's
        # Series) and restore NaT for these positions after conversion.
        dates[bad_locs] = 1.0  # Replace with NaT
    dates = dates.astype(np.int64)

    if fmt.startswith(("%tc", "tc")):  # Delta ms relative to base
        base = stata_epoch
        ms = dates
        conv_dates = convert_delta_safe(base, ms, "ms")
    elif fmt.startswith(("%tC", "tC")):
        warnings.warn(
            "Encountered %tC format. Leaving in Stata Internal Format.",
            stacklevel=find_stack_level(),
        )
        conv_dates = Series(dates, dtype=object)
        if has_bad_values:
            conv_dates[bad_locs] = NaT
        return conv_dates
    # Delta days relative to base
    elif fmt.startswith(("%td", "td", "%d", "d")):
        base = stata_epoch
        days = dates
        conv_dates = convert_delta_safe(base, days, "d")
    # does not count leap days - 7 days is a week.
    # 52nd week may have more than 7 days
    elif fmt.startswith(("%tw", "tw")):
        year = stata_epoch.year + dates // 52
        days = (dates % 52) * 7
        conv_dates = convert_year_days_safe(year, days)
    elif fmt.startswith(("%tm", "tm")):  # Delta months relative to base
        year = stata_epoch.year + dates // 12
        month = (dates % 12) + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
        year = stata_epoch.year + dates // 4
        quarter_month = (dates % 4) * 3 + 1
        conv_dates = convert_year_month_safe(year, quarter_month)
    elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
        year = stata_epoch.year + dates // 2
        month = (dates % 2) * 6 + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
        year = dates
        first_month = np.ones_like(dates)
        conv_dates = convert_year_month_safe(year, first_month)
    else:
        raise ValueError(f"Date fmt {fmt} not understood")

    if has_bad_values:  # Restore NaT for bad values
        conv_dates[bad_locs] = NaT

    return conv_dates
  345. def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
  346. """
  347. Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime
  348. Parameters
  349. ----------
  350. dates : Series
  351. Series or array containing datetime.datetime or datetime64[ns] to
  352. convert to the Stata Internal Format given by fmt
  353. fmt : str
  354. The format to convert to. Can be, tc, td, tw, tm, tq, th, ty
  355. """
  356. index = dates.index
  357. NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000
  358. US_PER_DAY = NS_PER_DAY / 1000
  359. def parse_dates_safe(
  360. dates, delta: bool = False, year: bool = False, days: bool = False
  361. ):
  362. d = {}
  363. if is_datetime64_dtype(dates.dtype):
  364. if delta:
  365. time_delta = dates - Timestamp(stata_epoch).as_unit("ns")
  366. d["delta"] = time_delta._values.view(np.int64) // 1000 # microseconds
  367. if days or year:
  368. date_index = DatetimeIndex(dates)
  369. d["year"] = date_index._data.year
  370. d["month"] = date_index._data.month
  371. if days:
  372. days_in_ns = dates.view(np.int64) - to_datetime(
  373. d["year"], format="%Y"
  374. ).view(np.int64)
  375. d["days"] = days_in_ns // NS_PER_DAY
  376. elif infer_dtype(dates, skipna=False) == "datetime":
  377. if delta:
  378. delta = dates._values - stata_epoch
  379. def f(x: datetime.timedelta) -> float:
  380. return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds
  381. v = np.vectorize(f)
  382. d["delta"] = v(delta)
  383. if year:
  384. year_month = dates.apply(lambda x: 100 * x.year + x.month)
  385. d["year"] = year_month._values // 100
  386. d["month"] = year_month._values - d["year"] * 100
  387. if days:
  388. def g(x: datetime.datetime) -> int:
  389. return (x - datetime.datetime(x.year, 1, 1)).days
  390. v = np.vectorize(g)
  391. d["days"] = v(dates)
  392. else:
  393. raise ValueError(
  394. "Columns containing dates must contain either "
  395. "datetime64, datetime.datetime or null values."
  396. )
  397. return DataFrame(d, index=index)
  398. bad_loc = isna(dates)
  399. index = dates.index
  400. if bad_loc.any():
  401. dates = Series(dates)
  402. if is_datetime64_dtype(dates):
  403. dates[bad_loc] = to_datetime(stata_epoch)
  404. else:
  405. dates[bad_loc] = stata_epoch
  406. if fmt in ["%tc", "tc"]:
  407. d = parse_dates_safe(dates, delta=True)
  408. conv_dates = d.delta / 1000
  409. elif fmt in ["%tC", "tC"]:
  410. warnings.warn(
  411. "Stata Internal Format tC not supported.",
  412. stacklevel=find_stack_level(),
  413. )
  414. conv_dates = dates
  415. elif fmt in ["%td", "td"]:
  416. d = parse_dates_safe(dates, delta=True)
  417. conv_dates = d.delta // US_PER_DAY
  418. elif fmt in ["%tw", "tw"]:
  419. d = parse_dates_safe(dates, year=True, days=True)
  420. conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7
  421. elif fmt in ["%tm", "tm"]:
  422. d = parse_dates_safe(dates, year=True)
  423. conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1
  424. elif fmt in ["%tq", "tq"]:
  425. d = parse_dates_safe(dates, year=True)
  426. conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3
  427. elif fmt in ["%th", "th"]:
  428. d = parse_dates_safe(dates, year=True)
  429. conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int)
  430. elif fmt in ["%ty", "ty"]:
  431. d = parse_dates_safe(dates, year=True)
  432. conv_dates = d.year
  433. else:
  434. raise ValueError(f"Format {fmt} is not a known Stata date format")
  435. conv_dates = Series(conv_dates, dtype=np.float64)
  436. missing_value = struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
  437. conv_dates[bad_loc] = missing_value
  438. return Series(conv_dates, index=index)
# User-facing message templates for problems encountered when writing .dta
# files. Positional placeholders ({0}, {1}) are meant to be filled with
# ``str.format``; the text itself is runtime output and must not be reflowed.

# {0}: name of a string column whose values exceed the fixed-width limit.
excessive_string_length_error: Final = """
Fixed width strings in Stata .dta files are limited to 244 (or fewer)
characters. Column '{0}' does not satisfy this restriction. Use the
'version=117' parameter to write the newer (Stata 13 and later) format.
"""
# {0}: original dtype, {1}: dtype the column was converted to.
precision_loss_doc: Final = """
Column converted from {0} to {1}, and some data are outside of the lossless
conversion range. This may result in a loss of precision in the saved data.
"""
# {0}: column whose categories are not all strings.
value_label_mismatch_doc: Final = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""
# {0}: description of the column-name replacements that were made.
invalid_name_doc: Final = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:
{0}
If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only
alphanumerics and underscores, no Stata reserved words)
"""
  452. """
  453. invalid_name_doc: Final = """
  454. Not all pandas column names were valid Stata variable names.
  455. The following replacements have been made:
  456. {0}
  457. If this is not what you expect, please make sure you have Stata-compliant
  458. column names in your DataFrame (strings only, max 32 characters, only
  459. alphanumerics and underscores, no Stata reserved words)
  460. """
  461. categorical_conversion_warning: Final = """
  462. One or more series with value labels are not fully labeled. Reading this
  463. dataset with an iterator results in categorical variable with different
  464. categories. This occurs since it is not possible to know all possible values
  465. until the entire dataset has been read. To avoid this warning, you can either
  466. read dataset without an iterator, or manually convert categorical data by
  467. ``convert_categoricals`` to False and then accessing the variable labels
  468. through the value_labels method of the reader.
  469. """
  470. def _cast_to_stata_types(data: DataFrame) -> DataFrame:
  471. """
  472. Checks the dtypes of the columns of a pandas DataFrame for
  473. compatibility with the data types and ranges supported by Stata, and
  474. converts if necessary.
  475. Parameters
  476. ----------
  477. data : DataFrame
  478. The DataFrame to check and convert
  479. Notes
  480. -----
  481. Numeric columns in Stata must be one of int8, int16, int32, float32 or
  482. float64, with some additional value restrictions. int8 and int16 columns
  483. are checked for violations of the value restrictions and upcast if needed.
  484. int64 data is not usable in Stata, and so it is downcast to int32 whenever
  485. the value are in the int32 range, and sidecast to float64 when larger than
  486. this range. If the int64 values are outside of the range of those
  487. perfectly representable as float64 values, a warning is raised.
  488. bool columns are cast to int8. uint columns are converted to int of the
  489. same size if there is no loss in precision, otherwise are upcast to a
  490. larger type. uint64 is currently not supported since it is concerted to
  491. object in a DataFrame.
  492. """
  493. ws = ""
  494. # original, if small, if large
  495. conversion_data: tuple[
  496. tuple[type, type, type],
  497. tuple[type, type, type],
  498. tuple[type, type, type],
  499. tuple[type, type, type],
  500. tuple[type, type, type],
  501. ] = (
  502. (np.bool_, np.int8, np.int8),
  503. (np.uint8, np.int8, np.int16),
  504. (np.uint16, np.int16, np.int32),
  505. (np.uint32, np.int32, np.int64),
  506. (np.uint64, np.int64, np.float64),
  507. )
  508. float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
  509. float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
  510. for col in data:
  511. # Cast from unsupported types to supported types
  512. is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype))
  513. orig = data[col]
  514. # We need to find orig_missing before altering data below
  515. orig_missing = orig.isna()
  516. if is_nullable_int:
  517. missing_loc = data[col].isna()
  518. if missing_loc.any():
  519. # Replace with always safe value
  520. fv = 0 if isinstance(data[col].dtype, IntegerDtype) else False
  521. data.loc[missing_loc, col] = fv
  522. # Replace with NumPy-compatible column
  523. data[col] = data[col].astype(data[col].dtype.numpy_dtype)
  524. dtype = data[col].dtype
  525. for c_data in conversion_data:
  526. if dtype == c_data[0]:
  527. if data[col].max() <= np.iinfo(c_data[1]).max:
  528. dtype = c_data[1]
  529. else:
  530. dtype = c_data[2]
  531. if c_data[2] == np.int64: # Warn if necessary
  532. if data[col].max() >= 2**53:
  533. ws = precision_loss_doc.format("uint64", "float64")
  534. data[col] = data[col].astype(dtype)
  535. # Check values and upcast if necessary
  536. if dtype == np.int8:
  537. if data[col].max() > 100 or data[col].min() < -127:
  538. data[col] = data[col].astype(np.int16)
  539. elif dtype == np.int16:
  540. if data[col].max() > 32740 or data[col].min() < -32767:
  541. data[col] = data[col].astype(np.int32)
  542. elif dtype == np.int64:
  543. if data[col].max() <= 2147483620 and data[col].min() >= -2147483647:
  544. data[col] = data[col].astype(np.int32)
  545. else:
  546. data[col] = data[col].astype(np.float64)
  547. if data[col].max() >= 2**53 or data[col].min() <= -(2**53):
  548. ws = precision_loss_doc.format("int64", "float64")
  549. elif dtype in (np.float32, np.float64):
  550. if np.isinf(data[col]).any():
  551. raise ValueError(
  552. f"Column {col} contains infinity or -infinity"
  553. "which is outside the range supported by Stata."
  554. )
  555. value = data[col].max()
  556. if dtype == np.float32 and value > float32_max:
  557. data[col] = data[col].astype(np.float64)
  558. elif dtype == np.float64:
  559. if value > float64_max:
  560. raise ValueError(
  561. f"Column {col} has a maximum value ({value}) outside the range "
  562. f"supported by Stata ({float64_max})"
  563. )
  564. if is_nullable_int:
  565. if orig_missing.any():
  566. # Replace missing by Stata sentinel value
  567. sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
  568. data.loc[orig_missing, col] = sentinel
  569. if ws:
  570. warnings.warn(
  571. ws,
  572. PossiblePrecisionLoss,
  573. stacklevel=find_stack_level(),
  574. )
  575. return data
  576. class StataValueLabel:
  577. """
  578. Parse a categorical column and prepare formatted output
  579. Parameters
  580. ----------
  581. catarray : Series
  582. Categorical Series to encode
  583. encoding : {"latin-1", "utf-8"}
  584. Encoding to use for value labels.
  585. """
  586. def __init__(
  587. self, catarray: Series, encoding: Literal["latin-1", "utf-8"] = "latin-1"
  588. ) -> None:
  589. if encoding not in ("latin-1", "utf-8"):
  590. raise ValueError("Only latin-1 and utf-8 are supported.")
  591. self.labname = catarray.name
  592. self._encoding = encoding
  593. categories = catarray.cat.categories
  594. self.value_labels: list[tuple[float, str]] = list(
  595. zip(np.arange(len(categories)), categories)
  596. )
  597. self.value_labels.sort(key=lambda x: x[0])
  598. self._prepare_value_labels()
  599. def _prepare_value_labels(self):
  600. """Encode value labels."""
  601. self.text_len = 0
  602. self.txt: list[bytes] = []
  603. self.n = 0
  604. # Offsets (length of categories), converted to int32
  605. self.off = np.array([], dtype=np.int32)
  606. # Values, converted to int32
  607. self.val = np.array([], dtype=np.int32)
  608. self.len = 0
  609. # Compute lengths and setup lists of offsets and labels
  610. offsets: list[int] = []
  611. values: list[float] = []
  612. for vl in self.value_labels:
  613. category: str | bytes = vl[1]
  614. if not isinstance(category, str):
  615. category = str(category)
  616. warnings.warn(
  617. value_label_mismatch_doc.format(self.labname),
  618. ValueLabelTypeMismatch,
  619. stacklevel=find_stack_level(),
  620. )
  621. category = category.encode(self._encoding)
  622. offsets.append(self.text_len)
  623. self.text_len += len(category) + 1 # +1 for the padding
  624. values.append(vl[0])
  625. self.txt.append(category)
  626. self.n += 1
  627. if self.text_len > 32000:
  628. raise ValueError(
  629. "Stata value labels for a single variable must "
  630. "have a combined length less than 32,000 characters."
  631. )
  632. # Ensure int32
  633. self.off = np.array(offsets, dtype=np.int32)
  634. self.val = np.array(values, dtype=np.int32)
  635. # Total length
  636. self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
  637. def generate_value_label(self, byteorder: str) -> bytes:
  638. """
  639. Generate the binary representation of the value labels.
  640. Parameters
  641. ----------
  642. byteorder : str
  643. Byte order of the output
  644. Returns
  645. -------
  646. value_label : bytes
  647. Bytes containing the formatted value label
  648. """
  649. encoding = self._encoding
  650. bio = BytesIO()
  651. null_byte = b"\x00"
  652. # len
  653. bio.write(struct.pack(byteorder + "i", self.len))
  654. # labname
  655. labname = str(self.labname)[:32].encode(encoding)
  656. lab_len = 32 if encoding not in ("utf-8", "utf8") else 128
  657. labname = _pad_bytes(labname, lab_len + 1)
  658. bio.write(labname)
  659. # padding - 3 bytes
  660. for i in range(3):
  661. bio.write(struct.pack("c", null_byte))
  662. # value_label_table
  663. # n - int32
  664. bio.write(struct.pack(byteorder + "i", self.n))
  665. # textlen - int32
  666. bio.write(struct.pack(byteorder + "i", self.text_len))
  667. # off - int32 array (n elements)
  668. for offset in self.off:
  669. bio.write(struct.pack(byteorder + "i", offset))
  670. # val - int32 array (n elements)
  671. for value in self.val:
  672. bio.write(struct.pack(byteorder + "i", value))
  673. # txt - Text labels, null terminated
  674. for text in self.txt:
  675. bio.write(text + null_byte)
  676. return bio.getvalue()
  677. class StataNonCatValueLabel(StataValueLabel):
  678. """
  679. Prepare formatted version of value labels
  680. Parameters
  681. ----------
  682. labname : str
  683. Value label name
  684. value_labels: Dictionary
  685. Mapping of values to labels
  686. encoding : {"latin-1", "utf-8"}
  687. Encoding to use for value labels.
  688. """
  689. def __init__(
  690. self,
  691. labname: str,
  692. value_labels: dict[float, str],
  693. encoding: Literal["latin-1", "utf-8"] = "latin-1",
  694. ) -> None:
  695. if encoding not in ("latin-1", "utf-8"):
  696. raise ValueError("Only latin-1 and utf-8 are supported.")
  697. self.labname = labname
  698. self._encoding = encoding
  699. self.value_labels: list[tuple[float, str]] = sorted(
  700. value_labels.items(), key=lambda x: x[0]
  701. )
  702. self._prepare_value_labels()
  703. class StataMissingValue:
  704. """
  705. An observation's missing value.
  706. Parameters
  707. ----------
  708. value : {int, float}
  709. The Stata missing value code
  710. Notes
  711. -----
  712. More information: <https://www.stata.com/help.cgi?missing>
  713. Integer missing values make the code '.', '.a', ..., '.z' to the ranges
  714. 101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
  715. 2147483647 (for int32). Missing values for floating point data types are
  716. more complex but the pattern is simple to discern from the following table.
  717. np.float32 missing values (float in Stata)
  718. 0000007f .
  719. 0008007f .a
  720. 0010007f .b
  721. ...
  722. 00c0007f .x
  723. 00c8007f .y
  724. 00d0007f .z
  725. np.float64 missing values (double in Stata)
  726. 000000000000e07f .
  727. 000000000001e07f .a
  728. 000000000002e07f .b
  729. ...
  730. 000000000018e07f .x
  731. 000000000019e07f .y
  732. 00000000001ae07f .z
  733. """
  734. # Construct a dictionary of missing values
  735. MISSING_VALUES: dict[float, str] = {}
  736. bases: Final = (101, 32741, 2147483621)
  737. for b in bases:
  738. # Conversion to long to avoid hash issues on 32 bit platforms #8968
  739. MISSING_VALUES[b] = "."
  740. for i in range(1, 27):
  741. MISSING_VALUES[i + b] = "." + chr(96 + i)
  742. float32_base: bytes = b"\x00\x00\x00\x7f"
  743. increment: int = struct.unpack("<i", b"\x00\x08\x00\x00")[0]
  744. for i in range(27):
  745. key = struct.unpack("<f", float32_base)[0]
  746. MISSING_VALUES[key] = "."
  747. if i > 0:
  748. MISSING_VALUES[key] += chr(96 + i)
  749. int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment
  750. float32_base = struct.pack("<i", int_value)
  751. float64_base: bytes = b"\x00\x00\x00\x00\x00\x00\xe0\x7f"
  752. increment = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0]
  753. for i in range(27):
  754. key = struct.unpack("<d", float64_base)[0]
  755. MISSING_VALUES[key] = "."
  756. if i > 0:
  757. MISSING_VALUES[key] += chr(96 + i)
  758. int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment
  759. float64_base = struct.pack("q", int_value)
  760. BASE_MISSING_VALUES: Final = {
  761. "int8": 101,
  762. "int16": 32741,
  763. "int32": 2147483621,
  764. "float32": struct.unpack("<f", float32_base)[0],
  765. "float64": struct.unpack("<d", float64_base)[0],
  766. }
  767. def __init__(self, value: float) -> None:
  768. self._value = value
  769. # Conversion to int to avoid hash issues on 32 bit platforms #8968
  770. value = int(value) if value < 2147483648 else float(value)
  771. self._str = self.MISSING_VALUES[value]
  772. @property
  773. def string(self) -> str:
  774. """
  775. The Stata representation of the missing value: '.', '.a'..'.z'
  776. Returns
  777. -------
  778. str
  779. The representation of the missing value.
  780. """
  781. return self._str
  782. @property
  783. def value(self) -> float:
  784. """
  785. The binary representation of the missing value.
  786. Returns
  787. -------
  788. {int, float}
  789. The binary representation of the missing value.
  790. """
  791. return self._value
  792. def __str__(self) -> str:
  793. return self.string
  794. def __repr__(self) -> str:
  795. return f"{type(self)}({self})"
  796. def __eq__(self, other: Any) -> bool:
  797. return (
  798. isinstance(other, type(self))
  799. and self.string == other.string
  800. and self.value == other.value
  801. )
  802. @classmethod
  803. def get_base_missing_value(cls, dtype: np.dtype) -> float:
  804. if dtype.type is np.int8:
  805. value = cls.BASE_MISSING_VALUES["int8"]
  806. elif dtype.type is np.int16:
  807. value = cls.BASE_MISSING_VALUES["int16"]
  808. elif dtype.type is np.int32:
  809. value = cls.BASE_MISSING_VALUES["int32"]
  810. elif dtype.type is np.float32:
  811. value = cls.BASE_MISSING_VALUES["float32"]
  812. elif dtype.type is np.float64:
  813. value = cls.BASE_MISSING_VALUES["float64"]
  814. else:
  815. raise ValueError("Unsupported dtype")
  816. return value
  817. class StataParser:
  818. def __init__(self) -> None:
  819. # type code.
  820. # --------------------
  821. # str1 1 = 0x01
  822. # str2 2 = 0x02
  823. # ...
  824. # str244 244 = 0xf4
  825. # byte 251 = 0xfb (sic)
  826. # int 252 = 0xfc
  827. # long 253 = 0xfd
  828. # float 254 = 0xfe
  829. # double 255 = 0xff
  830. # --------------------
  831. # NOTE: the byte type seems to be reserved for categorical variables
  832. # with a label, but the underlying variable is -127 to 100
  833. # we're going to drop the label and cast to int
  834. self.DTYPE_MAP = dict(
  835. list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
  836. + [
  837. (251, np.dtype(np.int8)),
  838. (252, np.dtype(np.int16)),
  839. (253, np.dtype(np.int32)),
  840. (254, np.dtype(np.float32)),
  841. (255, np.dtype(np.float64)),
  842. ]
  843. )
  844. self.DTYPE_MAP_XML: dict[int, np.dtype] = {
  845. 32768: np.dtype(np.uint8), # Keys to GSO
  846. 65526: np.dtype(np.float64),
  847. 65527: np.dtype(np.float32),
  848. 65528: np.dtype(np.int32),
  849. 65529: np.dtype(np.int16),
  850. 65530: np.dtype(np.int8),
  851. }
  852. self.TYPE_MAP = list(tuple(range(251)) + tuple("bhlfd"))
  853. self.TYPE_MAP_XML = {
  854. # Not really a Q, unclear how to handle byteswap
  855. 32768: "Q",
  856. 65526: "d",
  857. 65527: "f",
  858. 65528: "l",
  859. 65529: "h",
  860. 65530: "b",
  861. }
  862. # NOTE: technically, some of these are wrong. there are more numbers
  863. # that can be represented. it's the 27 ABOVE and BELOW the max listed
  864. # numeric data type in [U] 12.2.2 of the 11.2 manual
  865. float32_min = b"\xff\xff\xff\xfe"
  866. float32_max = b"\xff\xff\xff\x7e"
  867. float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff"
  868. float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f"
  869. self.VALID_RANGE = {
  870. "b": (-127, 100),
  871. "h": (-32767, 32740),
  872. "l": (-2147483647, 2147483620),
  873. "f": (
  874. np.float32(struct.unpack("<f", float32_min)[0]),
  875. np.float32(struct.unpack("<f", float32_max)[0]),
  876. ),
  877. "d": (
  878. np.float64(struct.unpack("<d", float64_min)[0]),
  879. np.float64(struct.unpack("<d", float64_max)[0]),
  880. ),
  881. }
  882. self.OLD_TYPE_MAPPING = {
  883. 98: 251, # byte
  884. 105: 252, # int
  885. 108: 253, # long
  886. 102: 254, # float
  887. 100: 255, # double
  888. }
  889. # These missing values are the generic '.' in Stata, and are used
  890. # to replace nans
  891. self.MISSING_VALUES = {
  892. "b": 101,
  893. "h": 32741,
  894. "l": 2147483621,
  895. "f": np.float32(struct.unpack("<f", b"\x00\x00\x00\x7f")[0]),
  896. "d": np.float64(
  897. struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
  898. ),
  899. }
  900. self.NUMPY_TYPE_MAP = {
  901. "b": "i1",
  902. "h": "i2",
  903. "l": "i4",
  904. "f": "f4",
  905. "d": "f8",
  906. "Q": "u8",
  907. }
  908. # Reserved words cannot be used as variable names
  909. self.RESERVED_WORDS = (
  910. "aggregate",
  911. "array",
  912. "boolean",
  913. "break",
  914. "byte",
  915. "case",
  916. "catch",
  917. "class",
  918. "colvector",
  919. "complex",
  920. "const",
  921. "continue",
  922. "default",
  923. "delegate",
  924. "delete",
  925. "do",
  926. "double",
  927. "else",
  928. "eltypedef",
  929. "end",
  930. "enum",
  931. "explicit",
  932. "export",
  933. "external",
  934. "float",
  935. "for",
  936. "friend",
  937. "function",
  938. "global",
  939. "goto",
  940. "if",
  941. "inline",
  942. "int",
  943. "local",
  944. "long",
  945. "NULL",
  946. "pragma",
  947. "protected",
  948. "quad",
  949. "rowvector",
  950. "short",
  951. "typedef",
  952. "typename",
  953. "virtual",
  954. "_all",
  955. "_N",
  956. "_skip",
  957. "_b",
  958. "_pi",
  959. "str#",
  960. "in",
  961. "_pred",
  962. "strL",
  963. "_coef",
  964. "_rc",
  965. "using",
  966. "_cons",
  967. "_se",
  968. "with",
  969. "_n",
  970. )
  971. class StataReader(StataParser, abc.Iterator):
  972. __doc__ = _stata_reader_doc
  973. _path_or_buf: IO[bytes] # Will be assigned by `_open_file`.
  974. def __init__(
  975. self,
  976. path_or_buf: FilePath | ReadBuffer[bytes],
  977. convert_dates: bool = True,
  978. convert_categoricals: bool = True,
  979. index_col: str | None = None,
  980. convert_missing: bool = False,
  981. preserve_dtypes: bool = True,
  982. columns: Sequence[str] | None = None,
  983. order_categoricals: bool = True,
  984. chunksize: int | None = None,
  985. compression: CompressionOptions = "infer",
  986. storage_options: StorageOptions = None,
  987. ) -> None:
  988. super().__init__()
  989. self._col_sizes: list[int] = []
  990. # Arguments to the reader (can be temporarily overridden in
  991. # calls to read).
  992. self._convert_dates = convert_dates
  993. self._convert_categoricals = convert_categoricals
  994. self._index_col = index_col
  995. self._convert_missing = convert_missing
  996. self._preserve_dtypes = preserve_dtypes
  997. self._columns = columns
  998. self._order_categoricals = order_categoricals
  999. self._original_path_or_buf = path_or_buf
  1000. self._compression = compression
  1001. self._storage_options = storage_options
  1002. self._encoding = ""
  1003. self._chunksize = chunksize
  1004. self._using_iterator = False
  1005. self._entered = False
  1006. if self._chunksize is None:
  1007. self._chunksize = 1
  1008. elif not isinstance(chunksize, int) or chunksize <= 0:
  1009. raise ValueError("chunksize must be a positive integer when set.")
  1010. # State variables for the file
  1011. self._close_file: Callable[[], None] | None = None
  1012. self._has_string_data = False
  1013. self._missing_values = False
  1014. self._can_read_value_labels = False
  1015. self._column_selector_set = False
  1016. self._value_labels_read = False
  1017. self._data_read = False
  1018. self._dtype: np.dtype | None = None
  1019. self._lines_read = 0
  1020. self._native_byteorder = _set_endianness(sys.byteorder)
  1021. def _ensure_open(self) -> None:
  1022. """
  1023. Ensure the file has been opened and its header data read.
  1024. """
  1025. if not hasattr(self, "_path_or_buf"):
  1026. self._open_file()
  1027. def _open_file(self) -> None:
  1028. """
  1029. Open the file (with compression options, etc.), and read header information.
  1030. """
  1031. if not self._entered:
  1032. warnings.warn(
  1033. "StataReader is being used without using a context manager. "
  1034. "Using StataReader as a context manager is the only supported method.",
  1035. ResourceWarning,
  1036. stacklevel=find_stack_level(),
  1037. )
  1038. handles = get_handle(
  1039. self._original_path_or_buf,
  1040. "rb",
  1041. storage_options=self._storage_options,
  1042. is_text=False,
  1043. compression=self._compression,
  1044. )
  1045. if hasattr(handles.handle, "seekable") and handles.handle.seekable():
  1046. # If the handle is directly seekable, use it without an extra copy.
  1047. self._path_or_buf = handles.handle
  1048. self._close_file = handles.close
  1049. else:
  1050. # Copy to memory, and ensure no encoding.
  1051. with handles:
  1052. self._path_or_buf = BytesIO(handles.handle.read())
  1053. self._close_file = self._path_or_buf.close
  1054. self._read_header()
  1055. self._setup_dtype()
  1056. def __enter__(self) -> StataReader:
  1057. """enter context manager"""
  1058. self._entered = True
  1059. return self
  1060. def __exit__(
  1061. self,
  1062. exc_type: type[BaseException] | None,
  1063. exc_value: BaseException | None,
  1064. traceback: TracebackType | None,
  1065. ) -> None:
  1066. if self._close_file:
  1067. self._close_file()
  1068. def close(self) -> None:
  1069. """Close the handle if its open.
  1070. .. deprecated: 2.0.0
  1071. The close method is not part of the public API.
  1072. The only supported way to use StataReader is to use it as a context manager.
  1073. """
  1074. warnings.warn(
  1075. "The StataReader.close() method is not part of the public API and "
  1076. "will be removed in a future version without notice. "
  1077. "Using StataReader as a context manager is the only supported method.",
  1078. FutureWarning,
  1079. stacklevel=find_stack_level(),
  1080. )
  1081. if self._close_file:
  1082. self._close_file()
  1083. def _set_encoding(self) -> None:
  1084. """
  1085. Set string encoding which depends on file version
  1086. """
  1087. if self._format_version < 118:
  1088. self._encoding = "latin-1"
  1089. else:
  1090. self._encoding = "utf-8"
  1091. def _read_int8(self) -> int:
  1092. return struct.unpack("b", self._path_or_buf.read(1))[0]
  1093. def _read_uint8(self) -> int:
  1094. return struct.unpack("B", self._path_or_buf.read(1))[0]
  1095. def _read_uint16(self) -> int:
  1096. return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0]
  1097. def _read_uint32(self) -> int:
  1098. return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0]
  1099. def _read_uint64(self) -> int:
  1100. return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0]
  1101. def _read_int16(self) -> int:
  1102. return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0]
  1103. def _read_int32(self) -> int:
  1104. return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0]
  1105. def _read_int64(self) -> int:
  1106. return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0]
  1107. def _read_char8(self) -> bytes:
  1108. return struct.unpack("c", self._path_or_buf.read(1))[0]
  1109. def _read_int16_count(self, count: int) -> tuple[int, ...]:
  1110. return struct.unpack(
  1111. f"{self._byteorder}{'h' * count}",
  1112. self._path_or_buf.read(2 * count),
  1113. )
  1114. def _read_header(self) -> None:
  1115. first_char = self._read_char8()
  1116. if first_char == b"<":
  1117. self._read_new_header()
  1118. else:
  1119. self._read_old_header(first_char)
  1120. self._has_string_data = len([x for x in self._typlist if type(x) is int]) > 0
  1121. # calculate size of a data record
  1122. self._col_sizes = [self._calcsize(typ) for typ in self._typlist]
  1123. def _read_new_header(self) -> None:
  1124. # The first part of the header is common to 117 - 119.
  1125. self._path_or_buf.read(27) # stata_dta><header><release>
  1126. self._format_version = int(self._path_or_buf.read(3))
  1127. if self._format_version not in [117, 118, 119]:
  1128. raise ValueError(_version_error.format(version=self._format_version))
  1129. self._set_encoding()
  1130. self._path_or_buf.read(21) # </release><byteorder>
  1131. self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<"
  1132. self._path_or_buf.read(15) # </byteorder><K>
  1133. self._nvar = (
  1134. self._read_uint16() if self._format_version <= 118 else self._read_uint32()
  1135. )
  1136. self._path_or_buf.read(7) # </K><N>
  1137. self._nobs = self._get_nobs()
  1138. self._path_or_buf.read(11) # </N><label>
  1139. self._data_label = self._get_data_label()
  1140. self._path_or_buf.read(19) # </label><timestamp>
  1141. self._time_stamp = self._get_time_stamp()
  1142. self._path_or_buf.read(26) # </timestamp></header><map>
  1143. self._path_or_buf.read(8) # 0x0000000000000000
  1144. self._path_or_buf.read(8) # position of <map>
  1145. self._seek_vartypes = self._read_int64() + 16
  1146. self._seek_varnames = self._read_int64() + 10
  1147. self._seek_sortlist = self._read_int64() + 10
  1148. self._seek_formats = self._read_int64() + 9
  1149. self._seek_value_label_names = self._read_int64() + 19
  1150. # Requires version-specific treatment
  1151. self._seek_variable_labels = self._get_seek_variable_labels()
  1152. self._path_or_buf.read(8) # <characteristics>
  1153. self._data_location = self._read_int64() + 6
  1154. self._seek_strls = self._read_int64() + 7
  1155. self._seek_value_labels = self._read_int64() + 14
  1156. self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes)
  1157. self._path_or_buf.seek(self._seek_varnames)
  1158. self._varlist = self._get_varlist()
  1159. self._path_or_buf.seek(self._seek_sortlist)
  1160. self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]
  1161. self._path_or_buf.seek(self._seek_formats)
  1162. self._fmtlist = self._get_fmtlist()
  1163. self._path_or_buf.seek(self._seek_value_label_names)
  1164. self._lbllist = self._get_lbllist()
  1165. self._path_or_buf.seek(self._seek_variable_labels)
  1166. self._variable_labels = self._get_variable_labels()
  1167. # Get data type information, works for versions 117-119.
  1168. def _get_dtypes(
  1169. self, seek_vartypes: int
  1170. ) -> tuple[list[int | str], list[str | np.dtype]]:
  1171. self._path_or_buf.seek(seek_vartypes)
  1172. raw_typlist = [self._read_uint16() for _ in range(self._nvar)]
  1173. def f(typ: int) -> int | str:
  1174. if typ <= 2045:
  1175. return typ
  1176. try:
  1177. return self.TYPE_MAP_XML[typ]
  1178. except KeyError as err:
  1179. raise ValueError(f"cannot convert stata types [{typ}]") from err
  1180. typlist = [f(x) for x in raw_typlist]
  1181. def g(typ: int) -> str | np.dtype:
  1182. if typ <= 2045:
  1183. return str(typ)
  1184. try:
  1185. return self.DTYPE_MAP_XML[typ]
  1186. except KeyError as err:
  1187. raise ValueError(f"cannot convert stata dtype [{typ}]") from err
  1188. dtyplist = [g(x) for x in raw_typlist]
  1189. return typlist, dtyplist
  1190. def _get_varlist(self) -> list[str]:
  1191. # 33 in order formats, 129 in formats 118 and 119
  1192. b = 33 if self._format_version < 118 else 129
  1193. return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
  1194. # Returns the format list
  1195. def _get_fmtlist(self) -> list[str]:
  1196. if self._format_version >= 118:
  1197. b = 57
  1198. elif self._format_version > 113:
  1199. b = 49
  1200. elif self._format_version > 104:
  1201. b = 12
  1202. else:
  1203. b = 7
  1204. return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
  1205. # Returns the label list
  1206. def _get_lbllist(self) -> list[str]:
  1207. if self._format_version >= 118:
  1208. b = 129
  1209. elif self._format_version > 108:
  1210. b = 33
  1211. else:
  1212. b = 9
  1213. return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
  1214. def _get_variable_labels(self) -> list[str]:
  1215. if self._format_version >= 118:
  1216. vlblist = [
  1217. self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar)
  1218. ]
  1219. elif self._format_version > 105:
  1220. vlblist = [
  1221. self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar)
  1222. ]
  1223. else:
  1224. vlblist = [
  1225. self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar)
  1226. ]
  1227. return vlblist
  1228. def _get_nobs(self) -> int:
  1229. if self._format_version >= 118:
  1230. return self._read_uint64()
  1231. else:
  1232. return self._read_uint32()
  1233. def _get_data_label(self) -> str:
  1234. if self._format_version >= 118:
  1235. strlen = self._read_uint16()
  1236. return self._decode(self._path_or_buf.read(strlen))
  1237. elif self._format_version == 117:
  1238. strlen = self._read_int8()
  1239. return self._decode(self._path_or_buf.read(strlen))
  1240. elif self._format_version > 105:
  1241. return self._decode(self._path_or_buf.read(81))
  1242. else:
  1243. return self._decode(self._path_or_buf.read(32))
  1244. def _get_time_stamp(self) -> str:
  1245. if self._format_version >= 118:
  1246. strlen = self._read_int8()
  1247. return self._path_or_buf.read(strlen).decode("utf-8")
  1248. elif self._format_version == 117:
  1249. strlen = self._read_int8()
  1250. return self._decode(self._path_or_buf.read(strlen))
  1251. elif self._format_version > 104:
  1252. return self._decode(self._path_or_buf.read(18))
  1253. else:
  1254. raise ValueError()
  1255. def _get_seek_variable_labels(self) -> int:
  1256. if self._format_version == 117:
  1257. self._path_or_buf.read(8) # <variable_labels>, throw away
  1258. # Stata 117 data files do not follow the described format. This is
  1259. # a work around that uses the previous label, 33 bytes for each
  1260. # variable, 20 for the closing tag and 17 for the opening tag
  1261. return self._seek_value_label_names + (33 * self._nvar) + 20 + 17
  1262. elif self._format_version >= 118:
  1263. return self._read_int64() + 17
  1264. else:
  1265. raise ValueError()
  1266. def _read_old_header(self, first_char: bytes) -> None:
  1267. self._format_version = int(first_char[0])
  1268. if self._format_version not in [104, 105, 108, 111, 113, 114, 115]:
  1269. raise ValueError(_version_error.format(version=self._format_version))
  1270. self._set_encoding()
  1271. self._byteorder = ">" if self._read_int8() == 0x1 else "<"
  1272. self._filetype = self._read_int8()
  1273. self._path_or_buf.read(1) # unused
  1274. self._nvar = self._read_uint16()
  1275. self._nobs = self._get_nobs()
  1276. self._data_label = self._get_data_label()
  1277. self._time_stamp = self._get_time_stamp()
  1278. # descriptors
  1279. if self._format_version > 108:
  1280. typlist = [int(c) for c in self._path_or_buf.read(self._nvar)]
  1281. else:
  1282. buf = self._path_or_buf.read(self._nvar)
  1283. typlistb = np.frombuffer(buf, dtype=np.uint8)
  1284. typlist = []
  1285. for tp in typlistb:
  1286. if tp in self.OLD_TYPE_MAPPING:
  1287. typlist.append(self.OLD_TYPE_MAPPING[tp])
  1288. else:
  1289. typlist.append(tp - 127) # bytes
  1290. try:
  1291. self._typlist = [self.TYPE_MAP[typ] for typ in typlist]
  1292. except ValueError as err:
  1293. invalid_types = ",".join([str(x) for x in typlist])
  1294. raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
  1295. try:
  1296. self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
  1297. except ValueError as err:
  1298. invalid_dtypes = ",".join([str(x) for x in typlist])
  1299. raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err
  1300. if self._format_version > 108:
  1301. self._varlist = [
  1302. self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar)
  1303. ]
  1304. else:
  1305. self._varlist = [
  1306. self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar)
  1307. ]
  1308. self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]
  1309. self._fmtlist = self._get_fmtlist()
  1310. self._lbllist = self._get_lbllist()
  1311. self._variable_labels = self._get_variable_labels()
  1312. # ignore expansion fields (Format 105 and later)
  1313. # When reading, read five bytes; the last four bytes now tell you
  1314. # the size of the next read, which you discard. You then continue
  1315. # like this until you read 5 bytes of zeros.
  1316. if self._format_version > 104:
  1317. while True:
  1318. data_type = self._read_int8()
  1319. if self._format_version > 108:
  1320. data_len = self._read_int32()
  1321. else:
  1322. data_len = self._read_int16()
  1323. if data_type == 0:
  1324. break
  1325. self._path_or_buf.read(data_len)
  1326. # necessary data to continue parsing
  1327. self._data_location = self._path_or_buf.tell()
  1328. def _setup_dtype(self) -> np.dtype:
  1329. """Map between numpy and state dtypes"""
  1330. if self._dtype is not None:
  1331. return self._dtype
  1332. dtypes = [] # Convert struct data types to numpy data type
  1333. for i, typ in enumerate(self._typlist):
  1334. if typ in self.NUMPY_TYPE_MAP:
  1335. typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP
  1336. dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}"))
  1337. else:
  1338. dtypes.append((f"s{i}", f"S{typ}"))
  1339. self._dtype = np.dtype(dtypes)
  1340. return self._dtype
  1341. def _calcsize(self, fmt: int | str) -> int:
  1342. if isinstance(fmt, int):
  1343. return fmt
  1344. return struct.calcsize(self._byteorder + fmt)
  1345. def _decode(self, s: bytes) -> str:
  1346. # have bytes not strings, so must decode
  1347. s = s.partition(b"\0")[0]
  1348. try:
  1349. return s.decode(self._encoding)
  1350. except UnicodeDecodeError:
  1351. # GH 25960, fallback to handle incorrect format produced when 117
  1352. # files are converted to 118 files in Stata
  1353. encoding = self._encoding
  1354. msg = f"""
  1355. One or more strings in the dta file could not be decoded using {encoding}, and
  1356. so the fallback encoding of latin-1 is being used. This can happen when a file
  1357. has been incorrectly encoded by Stata or some other software. You should verify
  1358. the string values returned are correct."""
  1359. warnings.warn(
  1360. msg,
  1361. UnicodeWarning,
  1362. stacklevel=find_stack_level(),
  1363. )
  1364. return s.decode("latin-1")
def _read_value_labels(self) -> None:
    """
    Read every value-label table in the file into ``self._value_label_dict``.

    The result maps each label name to a ``{value: label text}`` dict. The
    tables are read at most once (guarded by ``self._value_labels_read``).
    Format 108 and earlier have no value labels and yield an empty dict.
    """
    self._ensure_open()
    if self._value_labels_read:
        # Don't read twice
        return
    if self._format_version <= 108:
        # Value labels are not supported in version 108 and earlier.
        self._value_labels_read = True
        self._value_label_dict: dict[str, dict[float, str]] = {}
        return

    # Position the stream at the first table. Format 117+ records the offset
    # directly; older formats place the tables right after the fixed-width
    # data records (nobs rows of dtype.itemsize bytes each).
    if self._format_version >= 117:
        self._path_or_buf.seek(self._seek_value_labels)
    else:
        assert self._dtype is not None
        offset = self._nobs * self._dtype.itemsize
        self._path_or_buf.seek(self._data_location + offset)

    self._value_labels_read = True
    self._value_label_dict = {}

    while True:
        if self._format_version >= 117:
            if self._path_or_buf.read(5) == b"</val":  # <lbl>
                break  # end of value label table

        # 4-byte length field. Its value is not used; an empty read means
        # the end of the file was reached.
        slength = self._path_or_buf.read(4)
        if not slength:
            break  # end of value label table (format < 117)
        # Label name field: 33 bytes up to format 117, 129 bytes afterwards.
        if self._format_version <= 117:
            labname = self._decode(self._path_or_buf.read(33))
        else:
            labname = self._decode(self._path_or_buf.read(129))
        self._path_or_buf.read(3)  # padding

        # n entries, then the size of the text block that follows the arrays.
        n = self._read_uint32()
        txtlen = self._read_uint32()
        # n int32 text offsets, then n int32 values, in the file's byte order.
        off = np.frombuffer(
            self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
        )
        val = np.frombuffer(
            self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
        )
        # Sort entries by text offset so each label's text runs up to the
        # next entry's offset (or to txtlen for the last entry).
        ii = np.argsort(off)
        off = off[ii]
        val = val[ii]
        txt = self._path_or_buf.read(txtlen)
        self._value_label_dict[labname] = {}
        for i in range(n):
            end = off[i + 1] if i < n - 1 else txtlen
            self._value_label_dict[labname][val[i]] = self._decode(
                txt[off[i] : end]
            )
        if self._format_version >= 117:
            self._path_or_buf.read(6)  # </lbl>
    self._value_labels_read = True
  1416. def _read_strls(self) -> None:
  1417. self._path_or_buf.seek(self._seek_strls)
  1418. # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
  1419. self.GSO = {"0": ""}
  1420. while True:
  1421. if self._path_or_buf.read(3) != b"GSO":
  1422. break
  1423. if self._format_version == 117:
  1424. v_o = self._read_uint64()
  1425. else:
  1426. buf = self._path_or_buf.read(12)
  1427. # Only tested on little endian file on little endian machine.
  1428. v_size = 2 if self._format_version == 118 else 3
  1429. if self._byteorder == "<":
  1430. buf = buf[0:v_size] + buf[4 : (12 - v_size)]
  1431. else:
  1432. # This path may not be correct, impossible to test
  1433. buf = buf[0:v_size] + buf[(4 + v_size) :]
  1434. v_o = struct.unpack("Q", buf)[0]
  1435. typ = self._read_uint8()
  1436. length = self._read_uint32()
  1437. va = self._path_or_buf.read(length)
  1438. if typ == 130:
  1439. decoded_va = va[0:-1].decode(self._encoding)
  1440. else:
  1441. # Stata says typ 129 can be binary, so use str
  1442. decoded_va = str(va)
  1443. # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
  1444. self.GSO[str(v_o)] = decoded_va
  1445. def __next__(self) -> DataFrame:
  1446. self._using_iterator = True
  1447. return self.read(nrows=self._chunksize)
  1448. def get_chunk(self, size: int | None = None) -> DataFrame:
  1449. """
  1450. Reads lines from Stata file and returns as dataframe
  1451. Parameters
  1452. ----------
  1453. size : int, defaults to None
  1454. Number of lines to read. If None, reads whole file.
  1455. Returns
  1456. -------
  1457. DataFrame
  1458. """
  1459. if size is None:
  1460. size = self._chunksize
  1461. return self.read(nrows=size)
  1462. @Appender(_read_method_doc)
  1463. def read(
  1464. self,
  1465. nrows: int | None = None,
  1466. convert_dates: bool | None = None,
  1467. convert_categoricals: bool | None = None,
  1468. index_col: str | None = None,
  1469. convert_missing: bool | None = None,
  1470. preserve_dtypes: bool | None = None,
  1471. columns: Sequence[str] | None = None,
  1472. order_categoricals: bool | None = None,
  1473. ) -> DataFrame:
  1474. self._ensure_open()
  1475. # Handle empty file or chunk. If reading incrementally raise
  1476. # StopIteration. If reading the whole thing return an empty
  1477. # data frame.
  1478. if (self._nobs == 0) and (nrows is None):
  1479. self._can_read_value_labels = True
  1480. self._data_read = True
  1481. return DataFrame(columns=self._varlist)
  1482. # Handle options
  1483. if convert_dates is None:
  1484. convert_dates = self._convert_dates
  1485. if convert_categoricals is None:
  1486. convert_categoricals = self._convert_categoricals
  1487. if convert_missing is None:
  1488. convert_missing = self._convert_missing
  1489. if preserve_dtypes is None:
  1490. preserve_dtypes = self._preserve_dtypes
  1491. if columns is None:
  1492. columns = self._columns
  1493. if order_categoricals is None:
  1494. order_categoricals = self._order_categoricals
  1495. if index_col is None:
  1496. index_col = self._index_col
  1497. if nrows is None:
  1498. nrows = self._nobs
  1499. if (self._format_version >= 117) and (not self._value_labels_read):
  1500. self._can_read_value_labels = True
  1501. self._read_strls()
  1502. # Read data
  1503. assert self._dtype is not None
  1504. dtype = self._dtype
  1505. max_read_len = (self._nobs - self._lines_read) * dtype.itemsize
  1506. read_len = nrows * dtype.itemsize
  1507. read_len = min(read_len, max_read_len)
  1508. if read_len <= 0:
  1509. # Iterator has finished, should never be here unless
  1510. # we are reading the file incrementally
  1511. if convert_categoricals:
  1512. self._read_value_labels()
  1513. raise StopIteration
  1514. offset = self._lines_read * dtype.itemsize
  1515. self._path_or_buf.seek(self._data_location + offset)
  1516. read_lines = min(nrows, self._nobs - self._lines_read)
  1517. raw_data = np.frombuffer(
  1518. self._path_or_buf.read(read_len), dtype=dtype, count=read_lines
  1519. )
  1520. self._lines_read += read_lines
  1521. if self._lines_read == self._nobs:
  1522. self._can_read_value_labels = True
  1523. self._data_read = True
  1524. # if necessary, swap the byte order to native here
  1525. if self._byteorder != self._native_byteorder:
  1526. raw_data = raw_data.byteswap().newbyteorder()
  1527. if convert_categoricals:
  1528. self._read_value_labels()
  1529. if len(raw_data) == 0:
  1530. data = DataFrame(columns=self._varlist)
  1531. else:
  1532. data = DataFrame.from_records(raw_data)
  1533. data.columns = Index(self._varlist)
  1534. # If index is not specified, use actual row number rather than
  1535. # restarting at 0 for each chunk.
  1536. if index_col is None:
  1537. rng = range(self._lines_read - read_lines, self._lines_read)
  1538. data.index = Index(rng) # set attr instead of set_index to avoid copy
  1539. if columns is not None:
  1540. data = self._do_select_columns(data, columns)
  1541. # Decode strings
  1542. for col, typ in zip(data, self._typlist):
  1543. if type(typ) is int:
  1544. data[col] = data[col].apply(self._decode, convert_dtype=True)
  1545. data = self._insert_strls(data)
  1546. cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0]
  1547. # Convert columns (if needed) to match input type
  1548. ix = data.index
  1549. requires_type_conversion = False
  1550. data_formatted = []
  1551. for i in cols_:
  1552. if self._dtyplist[i] is not None:
  1553. col = data.columns[i]
  1554. dtype = data[col].dtype
  1555. if dtype != np.dtype(object) and dtype != self._dtyplist[i]:
  1556. requires_type_conversion = True
  1557. data_formatted.append(
  1558. (col, Series(data[col], ix, self._dtyplist[i]))
  1559. )
  1560. else:
  1561. data_formatted.append((col, data[col]))
  1562. if requires_type_conversion:
  1563. data = DataFrame.from_dict(dict(data_formatted))
  1564. del data_formatted
  1565. data = self._do_convert_missing(data, convert_missing)
  1566. if convert_dates:
  1567. def any_startswith(x: str) -> bool:
  1568. return any(x.startswith(fmt) for fmt in _date_formats)
  1569. cols = np.where([any_startswith(x) for x in self._fmtlist])[0]
  1570. for i in cols:
  1571. col = data.columns[i]
  1572. data[col] = _stata_elapsed_date_to_datetime_vec(
  1573. data[col], self._fmtlist[i]
  1574. )
  1575. if convert_categoricals and self._format_version > 108:
  1576. data = self._do_convert_categoricals(
  1577. data, self._value_label_dict, self._lbllist, order_categoricals
  1578. )
  1579. if not preserve_dtypes:
  1580. retyped_data = []
  1581. convert = False
  1582. for col in data:
  1583. dtype = data[col].dtype
  1584. if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
  1585. dtype = np.dtype(np.float64)
  1586. convert = True
  1587. elif dtype in (
  1588. np.dtype(np.int8),
  1589. np.dtype(np.int16),
  1590. np.dtype(np.int32),
  1591. ):
  1592. dtype = np.dtype(np.int64)
  1593. convert = True
  1594. retyped_data.append((col, data[col].astype(dtype)))
  1595. if convert:
  1596. data = DataFrame.from_dict(dict(retyped_data))
  1597. if index_col is not None:
  1598. data = data.set_index(data.pop(index_col))
  1599. return data
  1600. def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
  1601. # Check for missing values, and replace if found
  1602. replacements = {}
  1603. for i, colname in enumerate(data):
  1604. fmt = self._typlist[i]
  1605. if fmt not in self.VALID_RANGE:
  1606. continue
  1607. fmt = cast(str, fmt) # only strs in VALID_RANGE
  1608. nmin, nmax = self.VALID_RANGE[fmt]
  1609. series = data[colname]
  1610. # appreciably faster to do this with ndarray instead of Series
  1611. svals = series._values
  1612. missing = (svals < nmin) | (svals > nmax)
  1613. if not missing.any():
  1614. continue
  1615. if convert_missing: # Replacement follows Stata notation
  1616. missing_loc = np.nonzero(np.asarray(missing))[0]
  1617. umissing, umissing_loc = np.unique(series[missing], return_inverse=True)
  1618. replacement = Series(series, dtype=object)
  1619. for j, um in enumerate(umissing):
  1620. missing_value = StataMissingValue(um)
  1621. loc = missing_loc[umissing_loc == j]
  1622. replacement.iloc[loc] = missing_value
  1623. else: # All replacements are identical
  1624. dtype = series.dtype
  1625. if dtype not in (np.float32, np.float64):
  1626. dtype = np.float64
  1627. replacement = Series(series, dtype=dtype)
  1628. if not replacement._values.flags["WRITEABLE"]:
  1629. # only relevant for ArrayManager; construction
  1630. # path for BlockManager ensures writeability
  1631. replacement = replacement.copy()
  1632. # Note: operating on ._values is much faster than directly
  1633. # TODO: can we fix that?
  1634. replacement._values[missing] = np.nan
  1635. replacements[colname] = replacement
  1636. if replacements:
  1637. for col, value in replacements.items():
  1638. data[col] = value
  1639. return data
  1640. def _insert_strls(self, data: DataFrame) -> DataFrame:
  1641. if not hasattr(self, "GSO") or len(self.GSO) == 0:
  1642. return data
  1643. for i, typ in enumerate(self._typlist):
  1644. if typ != "Q":
  1645. continue
  1646. # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
  1647. data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]]
  1648. return data
  1649. def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame:
  1650. if not self._column_selector_set:
  1651. column_set = set(columns)
  1652. if len(column_set) != len(columns):
  1653. raise ValueError("columns contains duplicate entries")
  1654. unmatched = column_set.difference(data.columns)
  1655. if unmatched:
  1656. joined = ", ".join(list(unmatched))
  1657. raise ValueError(
  1658. "The following columns were not "
  1659. f"found in the Stata data set: {joined}"
  1660. )
  1661. # Copy information for retained columns for later processing
  1662. dtyplist = []
  1663. typlist = []
  1664. fmtlist = []
  1665. lbllist = []
  1666. for col in columns:
  1667. i = data.columns.get_loc(col)
  1668. dtyplist.append(self._dtyplist[i])
  1669. typlist.append(self._typlist[i])
  1670. fmtlist.append(self._fmtlist[i])
  1671. lbllist.append(self._lbllist[i])
  1672. self._dtyplist = dtyplist
  1673. self._typlist = typlist
  1674. self._fmtlist = fmtlist
  1675. self._lbllist = lbllist
  1676. self._column_selector_set = True
  1677. return data[columns]
def _do_convert_categoricals(
    self,
    data: DataFrame,
    value_label_dict: dict[str, dict[float, str]],
    lbllist: Sequence[str],
    order_categoricals: bool,
) -> DataFrame:
    """
    Converts categorical columns to Categorical type.

    A column is converted when its entry in ``lbllist`` names a table in
    ``value_label_dict``. Its values become categories renamed to the
    matching labels, and values without a label keep their raw value as the
    category. Other columns pass through unchanged.

    Parameters
    ----------
    data : DataFrame
        Frame whose columns line up positionally with ``lbllist``.
    value_label_dict : dict
        Label name -> {value: label text}.
    lbllist : sequence of str
        Label name for each column of ``data``.
    order_categoricals : bool
        Passed as ``ordered`` to ``Categorical``.

    Returns
    -------
    DataFrame
        A new DataFrame built from the (possibly converted) columns.

    Raises
    ------
    ValueError
        If the labels for a converted column are not unique.
    """
    value_labels = list(value_label_dict.keys())
    cat_converted_data = []
    for col, label in zip(data, lbllist):
        if label in value_labels:
            # Whether the result is ordered follows ``order_categoricals``.
            vl = value_label_dict[label]
            keys = np.array(list(vl.keys()))
            column = data[col]
            key_matches = column.isin(keys)
            if self._using_iterator and key_matches.all():
                initial_categories: np.ndarray | None = keys
                # If all categories are in the keys and we are iterating,
                # use the same keys for all chunks. If some are missing
                # value labels, then we will fall back to the categories
                # varying across chunks.
            else:
                if self._using_iterator:
                    # warn if using an iterator
                    warnings.warn(
                        categorical_conversion_warning,
                        CategoricalConversionWarning,
                        stacklevel=find_stack_level(),
                    )
                initial_categories = None
            cat_data = Categorical(
                column, categories=initial_categories, ordered=order_categoricals
            )
            if initial_categories is None:
                # If None here, then we need to match the cats in the Categorical
                categories = []
                for category in cat_data.categories:
                    if category in vl:
                        categories.append(vl[category])
                    else:
                        categories.append(category)
            else:
                # If all cats are matched, we can use the values: the
                # categories were set from ``keys``, which is in vl's order.
                categories = list(vl.values())
            try:
                # Try to catch duplicate categories
                # TODO: if we get a non-copying rename_categories, use that
                cat_data = cat_data.rename_categories(categories)
            except ValueError as err:
                vc = Series(categories, copy=False).value_counts()
                repeated_cats = list(vc.index[vc > 1])
                repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
                # GH 25772
                msg = f"""
Value labels for column {col} are not unique. These cannot be converted to
pandas categoricals.
Either read the file with `convert_categoricals` set to False or use the
low level interface in `StataReader` to separately read the values and the
value_labels.
The repeated labels are:
{repeats}
"""
                raise ValueError(msg) from err
            # TODO: is the next line needed above in the data(...) method?
            cat_series = Series(cat_data, index=data.index, copy=False)
            cat_converted_data.append((col, cat_series))
        else:
            cat_converted_data.append((col, data[col]))
    data = DataFrame(dict(cat_converted_data), copy=False)
    return data
  1752. @property
  1753. def data_label(self) -> str:
  1754. """
  1755. Return data label of Stata file.
  1756. """
  1757. self._ensure_open()
  1758. return self._data_label
  1759. @property
  1760. def time_stamp(self) -> str:
  1761. """
  1762. Return time stamp of Stata file.
  1763. """
  1764. self._ensure_open()
  1765. return self._time_stamp
  1766. def variable_labels(self) -> dict[str, str]:
  1767. """
  1768. Return a dict associating each variable name with corresponding label.
  1769. Returns
  1770. -------
  1771. dict
  1772. """
  1773. self._ensure_open()
  1774. return dict(zip(self._varlist, self._variable_labels))
  1775. def value_labels(self) -> dict[str, dict[float, str]]:
  1776. """
  1777. Return a nested dict associating each variable name to its value and label.
  1778. Returns
  1779. -------
  1780. dict
  1781. """
  1782. if not self._value_labels_read:
  1783. self._read_value_labels()
  1784. return self._value_label_dict
  1785. @Appender(_read_stata_doc)
  1786. def read_stata(
  1787. filepath_or_buffer: FilePath | ReadBuffer[bytes],
  1788. *,
  1789. convert_dates: bool = True,
  1790. convert_categoricals: bool = True,
  1791. index_col: str | None = None,
  1792. convert_missing: bool = False,
  1793. preserve_dtypes: bool = True,
  1794. columns: Sequence[str] | None = None,
  1795. order_categoricals: bool = True,
  1796. chunksize: int | None = None,
  1797. iterator: bool = False,
  1798. compression: CompressionOptions = "infer",
  1799. storage_options: StorageOptions = None,
  1800. ) -> DataFrame | StataReader:
  1801. reader = StataReader(
  1802. filepath_or_buffer,
  1803. convert_dates=convert_dates,
  1804. convert_categoricals=convert_categoricals,
  1805. index_col=index_col,
  1806. convert_missing=convert_missing,
  1807. preserve_dtypes=preserve_dtypes,
  1808. columns=columns,
  1809. order_categoricals=order_categoricals,
  1810. chunksize=chunksize,
  1811. storage_options=storage_options,
  1812. compression=compression,
  1813. )
  1814. if iterator or chunksize:
  1815. return reader
  1816. with reader:
  1817. return reader.read()
  1818. def _set_endianness(endianness: str) -> str:
  1819. if endianness.lower() in ["<", "little"]:
  1820. return "<"
  1821. elif endianness.lower() in [">", "big"]:
  1822. return ">"
  1823. else: # pragma : no cover
  1824. raise ValueError(f"Endianness {endianness} not understood")
  1825. def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
  1826. """
  1827. Take a char string and pads it with null bytes until it's length chars.
  1828. """
  1829. if isinstance(name, bytes):
  1830. return name + b"\x00" * (length - len(name))
  1831. return name + "\x00" * (length - len(name))
  1832. def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
  1833. """
  1834. Convert from one of the stata date formats to a type in TYPE_MAP.
  1835. """
  1836. if fmt in [
  1837. "tc",
  1838. "%tc",
  1839. "td",
  1840. "%td",
  1841. "tw",
  1842. "%tw",
  1843. "tm",
  1844. "%tm",
  1845. "tq",
  1846. "%tq",
  1847. "th",
  1848. "%th",
  1849. "ty",
  1850. "%ty",
  1851. ]:
  1852. return np.dtype(np.float64) # Stata expects doubles for SIFs
  1853. else:
  1854. raise NotImplementedError(f"Format {fmt} not implemented")
  1855. def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict:
  1856. new_dict = {}
  1857. for key in convert_dates:
  1858. if not convert_dates[key].startswith("%"): # make sure proper fmts
  1859. convert_dates[key] = "%" + convert_dates[key]
  1860. if key in varlist:
  1861. new_dict.update({varlist.index(key): convert_dates[key]})
  1862. else:
  1863. if not isinstance(key, int):
  1864. raise ValueError("convert_dates key must be a column or an integer")
  1865. new_dict.update({key: convert_dates[key]})
  1866. return new_dict
  1867. def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int:
  1868. """
  1869. Convert dtype types to stata types. Returns the byte of the given ordinal.
  1870. See TYPE_MAP and comments for an explanation. This is also explained in
  1871. the dta spec.
  1872. 1 - 244 are strings of this length
  1873. Pandas Stata
  1874. 251 - for int8 byte
  1875. 252 - for int16 int
  1876. 253 - for int32 long
  1877. 254 - for float32 float
  1878. 255 - for double double
  1879. If there are dates to convert, then dtype will already have the correct
  1880. type inserted.
  1881. """
  1882. # TODO: expand to handle datetime to integer conversion
  1883. if dtype.type is np.object_: # try to coerce it to the biggest string
  1884. # not memory efficient, what else could we
  1885. # do?
  1886. itemsize = max_len_string_array(ensure_object(column._values))
  1887. return max(itemsize, 1)
  1888. elif dtype.type is np.float64:
  1889. return 255
  1890. elif dtype.type is np.float32:
  1891. return 254
  1892. elif dtype.type is np.int32:
  1893. return 253
  1894. elif dtype.type is np.int16:
  1895. return 252
  1896. elif dtype.type is np.int8:
  1897. return 251
  1898. else: # pragma : no cover
  1899. raise NotImplementedError(f"Data type {dtype} not supported.")
  1900. def _dtype_to_default_stata_fmt(
  1901. dtype, column: Series, dta_version: int = 114, force_strl: bool = False
  1902. ) -> str:
  1903. """
  1904. Map numpy dtype to stata's default format for this type. Not terribly
  1905. important since users can change this in Stata. Semantics are
  1906. object -> "%DDs" where DD is the length of the string. If not a string,
  1907. raise ValueError
  1908. float64 -> "%10.0g"
  1909. float32 -> "%9.0g"
  1910. int64 -> "%9.0g"
  1911. int32 -> "%12.0g"
  1912. int16 -> "%8.0g"
  1913. int8 -> "%8.0g"
  1914. strl -> "%9s"
  1915. """
  1916. # TODO: Refactor to combine type with format
  1917. # TODO: expand this to handle a default datetime format?
  1918. if dta_version < 117:
  1919. max_str_len = 244
  1920. else:
  1921. max_str_len = 2045
  1922. if force_strl:
  1923. return "%9s"
  1924. if dtype.type is np.object_:
  1925. itemsize = max_len_string_array(ensure_object(column._values))
  1926. if itemsize > max_str_len:
  1927. if dta_version >= 117:
  1928. return "%9s"
  1929. else:
  1930. raise ValueError(excessive_string_length_error.format(column.name))
  1931. return "%" + str(max(itemsize, 1)) + "s"
  1932. elif dtype == np.float64:
  1933. return "%10.0g"
  1934. elif dtype == np.float32:
  1935. return "%9.0g"
  1936. elif dtype == np.int32:
  1937. return "%12.0g"
  1938. elif dtype in (np.int8, np.int16):
  1939. return "%8.0g"
  1940. else: # pragma : no cover
  1941. raise NotImplementedError(f"Data type {dtype} not supported.")
  1942. @doc(
  1943. storage_options=_shared_docs["storage_options"],
  1944. compression_options=_shared_docs["compression_options"] % "fname",
  1945. )
  1946. class StataWriter(StataParser):
  1947. """
  1948. A class for writing Stata binary dta files
  1949. Parameters
  1950. ----------
  1951. fname : path (string), buffer or path object
  1952. string, path object (pathlib.Path or py._path.local.LocalPath) or
  1953. object implementing a binary write() functions. If using a buffer
  1954. then the buffer will not be automatically closed after the file
  1955. is written.
  1956. data : DataFrame
  1957. Input to save
  1958. convert_dates : dict
  1959. Dictionary mapping columns containing datetime types to stata internal
  1960. format to use when writing the dates. Options are 'tc', 'td', 'tm',
  1961. 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
  1962. Datetime columns that do not have a conversion type specified will be
  1963. converted to 'tc'. Raises NotImplementedError if a datetime column has
  1964. timezone information
  1965. write_index : bool
  1966. Write the index to Stata dataset.
  1967. byteorder : str
  1968. Can be ">", "<", "little", or "big". default is `sys.byteorder`
  1969. time_stamp : datetime
  1970. A datetime to use as file creation date. Default is the current time
  1971. data_label : str
  1972. A label for the data set. Must be 80 characters or smaller.
  1973. variable_labels : dict
  1974. Dictionary containing columns as keys and variable labels as values.
  1975. Each label must be 80 characters or smaller.
  1976. {compression_options}
  1977. .. versionadded:: 1.1.0
  1978. .. versionchanged:: 1.4.0 Zstandard support.
  1979. {storage_options}
  1980. .. versionadded:: 1.2.0
  1981. value_labels : dict of dicts
  1982. Dictionary containing columns as keys and dictionaries of column value
  1983. to labels as values. The combined length of all labels for a single
  1984. variable must be 32,000 characters or smaller.
  1985. .. versionadded:: 1.4.0
  1986. Returns
  1987. -------
  1988. writer : StataWriter instance
  1989. The StataWriter instance has a write_file method, which will
  1990. write the file to the given `fname`.
  1991. Raises
  1992. ------
  1993. NotImplementedError
  1994. * If datetimes contain timezone information
  1995. ValueError
  1996. * Columns listed in convert_dates are neither datetime64[ns]
  1997. or datetime.datetime
  1998. * Column dtype is not representable in Stata
  1999. * Column listed in convert_dates is not in DataFrame
  2000. * Categorical label contains more than 32,000 characters
  2001. Examples
  2002. --------
  2003. >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b'])
  2004. >>> writer = StataWriter('./data_file.dta', data)
  2005. >>> writer.write_file()
  2006. Directly write a zip file
  2007. >>> compression = {{"method": "zip", "archive_name": "data_file.dta"}}
  2008. >>> writer = StataWriter('./data_file.zip', data, compression=compression)
  2009. >>> writer.write_file()
  2010. Save a DataFrame with dates
  2011. >>> from datetime import datetime
  2012. >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
  2013. >>> writer = StataWriter('./date_data_file.dta', data, {{'date' : 'tw'}})
  2014. >>> writer.write_file()
  2015. """
  2016. _max_string_length = 244
  2017. _encoding: Literal["latin-1", "utf-8"] = "latin-1"
  2018. def __init__(
  2019. self,
  2020. fname: FilePath | WriteBuffer[bytes],
  2021. data: DataFrame,
  2022. convert_dates: dict[Hashable, str] | None = None,
  2023. write_index: bool = True,
  2024. byteorder: str | None = None,
  2025. time_stamp: datetime.datetime | None = None,
  2026. data_label: str | None = None,
  2027. variable_labels: dict[Hashable, str] | None = None,
  2028. compression: CompressionOptions = "infer",
  2029. storage_options: StorageOptions = None,
  2030. *,
  2031. value_labels: dict[Hashable, dict[float, str]] | None = None,
  2032. ) -> None:
  2033. super().__init__()
  2034. self.data = data
  2035. self._convert_dates = {} if convert_dates is None else convert_dates
  2036. self._write_index = write_index
  2037. self._time_stamp = time_stamp
  2038. self._data_label = data_label
  2039. self._variable_labels = variable_labels
  2040. self._non_cat_value_labels = value_labels
  2041. self._value_labels: list[StataValueLabel] = []
  2042. self._has_value_labels = np.array([], dtype=bool)
  2043. self._compression = compression
  2044. self._output_file: IO[bytes] | None = None
  2045. self._converted_names: dict[Hashable, str] = {}
  2046. # attach nobs, nvars, data, varlist, typlist
  2047. self._prepare_pandas(data)
  2048. self.storage_options = storage_options
  2049. if byteorder is None:
  2050. byteorder = sys.byteorder
  2051. self._byteorder = _set_endianness(byteorder)
  2052. self._fname = fname
  2053. self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
  2054. def _write(self, to_write: str) -> None:
  2055. """
  2056. Helper to call encode before writing to file for Python 3 compat.
  2057. """
  2058. self.handles.handle.write(to_write.encode(self._encoding))
  2059. def _write_bytes(self, value: bytes) -> None:
  2060. """
  2061. Helper to assert file is open before writing.
  2062. """
  2063. self.handles.handle.write(value)
  2064. def _prepare_non_cat_value_labels(
  2065. self, data: DataFrame
  2066. ) -> list[StataNonCatValueLabel]:
  2067. """
  2068. Check for value labels provided for non-categorical columns. Value
  2069. labels
  2070. """
  2071. non_cat_value_labels: list[StataNonCatValueLabel] = []
  2072. if self._non_cat_value_labels is None:
  2073. return non_cat_value_labels
  2074. for labname, labels in self._non_cat_value_labels.items():
  2075. if labname in self._converted_names:
  2076. colname = self._converted_names[labname]
  2077. elif labname in data.columns:
  2078. colname = str(labname)
  2079. else:
  2080. raise KeyError(
  2081. f"Can't create value labels for {labname}, it wasn't "
  2082. "found in the dataset."
  2083. )
  2084. if not is_numeric_dtype(data[colname].dtype):
  2085. # Labels should not be passed explicitly for categorical
  2086. # columns that will be converted to int
  2087. raise ValueError(
  2088. f"Can't create value labels for {labname}, value labels "
  2089. "can only be applied to numeric columns."
  2090. )
  2091. svl = StataNonCatValueLabel(colname, labels, self._encoding)
  2092. non_cat_value_labels.append(svl)
  2093. return non_cat_value_labels
  2094. def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
  2095. """
  2096. Check for categorical columns, retain categorical information for
  2097. Stata file and convert categorical data to int
  2098. """
  2099. is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
  2100. if not any(is_cat):
  2101. return data
  2102. self._has_value_labels |= np.array(is_cat)
  2103. get_base_missing_value = StataMissingValue.get_base_missing_value
  2104. data_formatted = []
  2105. for col, col_is_cat in zip(data, is_cat):
  2106. if col_is_cat:
  2107. svl = StataValueLabel(data[col], encoding=self._encoding)
  2108. self._value_labels.append(svl)
  2109. dtype = data[col].cat.codes.dtype
  2110. if dtype == np.int64:
  2111. raise ValueError(
  2112. "It is not possible to export "
  2113. "int64-based categorical data to Stata."
  2114. )
  2115. values = data[col].cat.codes._values.copy()
  2116. # Upcast if needed so that correct missing values can be set
  2117. if values.max() >= get_base_missing_value(dtype):
  2118. if dtype == np.int8:
  2119. dtype = np.dtype(np.int16)
  2120. elif dtype == np.int16:
  2121. dtype = np.dtype(np.int32)
  2122. else:
  2123. dtype = np.dtype(np.float64)
  2124. values = np.array(values, dtype=dtype)
  2125. # Replace missing values with Stata missing value for type
  2126. values[values == -1] = get_base_missing_value(dtype)
  2127. data_formatted.append((col, values))
  2128. else:
  2129. data_formatted.append((col, data[col]))
  2130. return DataFrame.from_dict(dict(data_formatted))
  2131. def _replace_nans(self, data: DataFrame) -> DataFrame:
  2132. # return data
  2133. """
  2134. Checks floating point data columns for nans, and replaces these with
  2135. the generic Stata for missing value (.)
  2136. """
  2137. for c in data:
  2138. dtype = data[c].dtype
  2139. if dtype in (np.float32, np.float64):
  2140. if dtype == np.float32:
  2141. replacement = self.MISSING_VALUES["f"]
  2142. else:
  2143. replacement = self.MISSING_VALUES["d"]
  2144. data[c] = data[c].fillna(replacement)
  2145. return data
  2146. def _update_strl_names(self) -> None:
  2147. """No-op, forward compatibility"""
  2148. def _validate_variable_name(self, name: str) -> str:
  2149. """
  2150. Validate variable names for Stata export.
  2151. Parameters
  2152. ----------
  2153. name : str
  2154. Variable name
  2155. Returns
  2156. -------
  2157. str
  2158. The validated name with invalid characters replaced with
  2159. underscores.
  2160. Notes
  2161. -----
  2162. Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9
  2163. and _.
  2164. """
  2165. for c in name:
  2166. if (
  2167. (c < "A" or c > "Z")
  2168. and (c < "a" or c > "z")
  2169. and (c < "0" or c > "9")
  2170. and c != "_"
  2171. ):
  2172. name = name.replace(c, "_")
  2173. return name
  2174. def _check_column_names(self, data: DataFrame) -> DataFrame:
  2175. """
  2176. Checks column names to ensure that they are valid Stata column names.
  2177. This includes checks for:
  2178. * Non-string names
  2179. * Stata keywords
  2180. * Variables that start with numbers
  2181. * Variables with names that are too long
  2182. When an illegal variable name is detected, it is converted, and if
  2183. dates are exported, the variable name is propagated to the date
  2184. conversion dictionary
  2185. """
  2186. converted_names: dict[Hashable, str] = {}
  2187. columns = list(data.columns)
  2188. original_columns = columns[:]
  2189. duplicate_var_id = 0
  2190. for j, name in enumerate(columns):
  2191. orig_name = name
  2192. if not isinstance(name, str):
  2193. name = str(name)
  2194. name = self._validate_variable_name(name)
  2195. # Variable name must not be a reserved word
  2196. if name in self.RESERVED_WORDS:
  2197. name = "_" + name
  2198. # Variable name may not start with a number
  2199. if "0" <= name[0] <= "9":
  2200. name = "_" + name
  2201. name = name[: min(len(name), 32)]
  2202. if not name == orig_name:
  2203. # check for duplicates
  2204. while columns.count(name) > 0:
  2205. # prepend ascending number to avoid duplicates
  2206. name = "_" + str(duplicate_var_id) + name
  2207. name = name[: min(len(name), 32)]
  2208. duplicate_var_id += 1
  2209. converted_names[orig_name] = name
  2210. columns[j] = name
  2211. data.columns = Index(columns)
  2212. # Check date conversion, and fix key if needed
  2213. if self._convert_dates:
  2214. for c, o in zip(columns, original_columns):
  2215. if c != o:
  2216. self._convert_dates[c] = self._convert_dates[o]
  2217. del self._convert_dates[o]
  2218. if converted_names:
  2219. conversion_warning = []
  2220. for orig_name, name in converted_names.items():
  2221. msg = f"{orig_name} -> {name}"
  2222. conversion_warning.append(msg)
  2223. ws = invalid_name_doc.format("\n ".join(conversion_warning))
  2224. warnings.warn(
  2225. ws,
  2226. InvalidColumnName,
  2227. stacklevel=find_stack_level(),
  2228. )
  2229. self._converted_names = converted_names
  2230. self._update_strl_names()
  2231. return data
  2232. def _set_formats_and_types(self, dtypes: Series) -> None:
  2233. self.fmtlist: list[str] = []
  2234. self.typlist: list[int] = []
  2235. for col, dtype in dtypes.items():
  2236. self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col]))
  2237. self.typlist.append(_dtype_to_stata_type(dtype, self.data[col]))
  2238. def _prepare_pandas(self, data: DataFrame) -> None:
  2239. # NOTE: we might need a different API / class for pandas objects so
  2240. # we can set different semantics - handle this with a PR to pandas.io
  2241. data = data.copy()
  2242. if self._write_index:
  2243. temp = data.reset_index()
  2244. if isinstance(temp, DataFrame):
  2245. data = temp
  2246. # Ensure column names are strings
  2247. data = self._check_column_names(data)
  2248. # Check columns for compatibility with stata, upcast if necessary
  2249. # Raise if outside the supported range
  2250. data = _cast_to_stata_types(data)
  2251. # Replace NaNs with Stata missing values
  2252. data = self._replace_nans(data)
  2253. # Set all columns to initially unlabelled
  2254. self._has_value_labels = np.repeat(False, data.shape[1])
  2255. # Create value labels for non-categorical data
  2256. non_cat_value_labels = self._prepare_non_cat_value_labels(data)
  2257. non_cat_columns = [svl.labname for svl in non_cat_value_labels]
  2258. has_non_cat_val_labels = data.columns.isin(non_cat_columns)
  2259. self._has_value_labels |= has_non_cat_val_labels
  2260. self._value_labels.extend(non_cat_value_labels)
  2261. # Convert categoricals to int data, and strip labels
  2262. data = self._prepare_categoricals(data)
  2263. self.nobs, self.nvar = data.shape
  2264. self.data = data
  2265. self.varlist = data.columns.tolist()
  2266. dtypes = data.dtypes
  2267. # Ensure all date columns are converted
  2268. for col in data:
  2269. if col in self._convert_dates:
  2270. continue
  2271. if is_datetime64_dtype(data[col]):
  2272. self._convert_dates[col] = "tc"
  2273. self._convert_dates = _maybe_convert_to_int_keys(
  2274. self._convert_dates, self.varlist
  2275. )
  2276. for key in self._convert_dates:
  2277. new_type = _convert_datetime_to_stata_type(self._convert_dates[key])
  2278. dtypes[key] = np.dtype(new_type)
  2279. # Verify object arrays are strings and encode to bytes
  2280. self._encode_strings()
  2281. self._set_formats_and_types(dtypes)
  2282. # set the given format for the datetime cols
  2283. if self._convert_dates is not None:
  2284. for key in self._convert_dates:
  2285. if isinstance(key, int):
  2286. self.fmtlist[key] = self._convert_dates[key]
  2287. def _encode_strings(self) -> None:
  2288. """
  2289. Encode strings in dta-specific encoding
  2290. Do not encode columns marked for date conversion or for strL
  2291. conversion. The strL converter independently handles conversion and
  2292. also accepts empty string arrays.
  2293. """
  2294. convert_dates = self._convert_dates
  2295. # _convert_strl is not available in dta 114
  2296. convert_strl = getattr(self, "_convert_strl", [])
  2297. for i, col in enumerate(self.data):
  2298. # Skip columns marked for date conversion or strl conversion
  2299. if i in convert_dates or col in convert_strl:
  2300. continue
  2301. column = self.data[col]
  2302. dtype = column.dtype
  2303. if dtype.type is np.object_:
  2304. inferred_dtype = infer_dtype(column, skipna=True)
  2305. if not ((inferred_dtype == "string") or len(column) == 0):
  2306. col = column.name
  2307. raise ValueError(
  2308. f"""\
  2309. Column `{col}` cannot be exported.\n\nOnly string-like object arrays
  2310. containing all strings or a mix of strings and None can be exported.
  2311. Object arrays containing only null values are prohibited. Other object
  2312. types cannot be exported and must first be converted to one of the
  2313. supported types."""
  2314. )
  2315. encoded = self.data[col].str.encode(self._encoding)
  2316. # If larger than _max_string_length do nothing
  2317. if (
  2318. max_len_string_array(ensure_object(encoded._values))
  2319. <= self._max_string_length
  2320. ):
  2321. self.data[col] = encoded
  2322. def write_file(self) -> None:
  2323. """
  2324. Export DataFrame object to Stata dta format.
  2325. """
  2326. with get_handle(
  2327. self._fname,
  2328. "wb",
  2329. compression=self._compression,
  2330. is_text=False,
  2331. storage_options=self.storage_options,
  2332. ) as self.handles:
  2333. if self.handles.compression["method"] is not None:
  2334. # ZipFile creates a file (with the same name) for each write call.
  2335. # Write it first into a buffer and then write the buffer to the ZipFile.
  2336. self._output_file, self.handles.handle = self.handles.handle, BytesIO()
  2337. self.handles.created_handles.append(self.handles.handle)
  2338. try:
  2339. self._write_header(
  2340. data_label=self._data_label, time_stamp=self._time_stamp
  2341. )
  2342. self._write_map()
  2343. self._write_variable_types()
  2344. self._write_varnames()
  2345. self._write_sortlist()
  2346. self._write_formats()
  2347. self._write_value_label_names()
  2348. self._write_variable_labels()
  2349. self._write_expansion_fields()
  2350. self._write_characteristics()
  2351. records = self._prepare_data()
  2352. self._write_data(records)
  2353. self._write_strls()
  2354. self._write_value_labels()
  2355. self._write_file_close_tag()
  2356. self._write_map()
  2357. self._close()
  2358. except Exception as exc:
  2359. self.handles.close()
  2360. if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile(
  2361. self._fname
  2362. ):
  2363. try:
  2364. os.unlink(self._fname)
  2365. except OSError:
  2366. warnings.warn(
  2367. f"This save was not successful but {self._fname} could not "
  2368. "be deleted. This file is not valid.",
  2369. ResourceWarning,
  2370. stacklevel=find_stack_level(),
  2371. )
  2372. raise exc
  2373. def _close(self) -> None:
  2374. """
  2375. Close the file if it was created by the writer.
  2376. If a buffer or file-like object was passed in, for example a GzipFile,
  2377. then leave this file open for the caller to close.
  2378. """
  2379. # write compression
  2380. if self._output_file is not None:
  2381. assert isinstance(self.handles.handle, BytesIO)
  2382. bio, self.handles.handle = self.handles.handle, self._output_file
  2383. self.handles.handle.write(bio.getvalue())
  2384. def _write_map(self) -> None:
  2385. """No-op, future compatibility"""
  2386. def _write_file_close_tag(self) -> None:
  2387. """No-op, future compatibility"""
  2388. def _write_characteristics(self) -> None:
  2389. """No-op, future compatibility"""
  2390. def _write_strls(self) -> None:
  2391. """No-op, future compatibility"""
  2392. def _write_expansion_fields(self) -> None:
  2393. """Write 5 zeros for expansion fields"""
  2394. self._write(_pad_bytes("", 5))
  2395. def _write_value_labels(self) -> None:
  2396. for vl in self._value_labels:
  2397. self._write_bytes(vl.generate_value_label(self._byteorder))
  2398. def _write_header(
  2399. self,
  2400. data_label: str | None = None,
  2401. time_stamp: datetime.datetime | None = None,
  2402. ) -> None:
  2403. byteorder = self._byteorder
  2404. # ds_format - just use 114
  2405. self._write_bytes(struct.pack("b", 114))
  2406. # byteorder
  2407. self._write(byteorder == ">" and "\x01" or "\x02")
  2408. # filetype
  2409. self._write("\x01")
  2410. # unused
  2411. self._write("\x00")
  2412. # number of vars, 2 bytes
  2413. self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2])
  2414. # number of obs, 4 bytes
  2415. self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4])
  2416. # data label 81 bytes, char, null terminated
  2417. if data_label is None:
  2418. self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80)))
  2419. else:
  2420. self._write_bytes(
  2421. self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))
  2422. )
  2423. # time stamp, 18 bytes, char, null terminated
  2424. # format dd Mon yyyy hh:mm
  2425. if time_stamp is None:
  2426. time_stamp = datetime.datetime.now()
  2427. elif not isinstance(time_stamp, datetime.datetime):
  2428. raise ValueError("time_stamp should be datetime type")
  2429. # GH #13856
  2430. # Avoid locale-specific month conversion
  2431. months = [
  2432. "Jan",
  2433. "Feb",
  2434. "Mar",
  2435. "Apr",
  2436. "May",
  2437. "Jun",
  2438. "Jul",
  2439. "Aug",
  2440. "Sep",
  2441. "Oct",
  2442. "Nov",
  2443. "Dec",
  2444. ]
  2445. month_lookup = {i + 1: month for i, month in enumerate(months)}
  2446. ts = (
  2447. time_stamp.strftime("%d ")
  2448. + month_lookup[time_stamp.month]
  2449. + time_stamp.strftime(" %Y %H:%M")
  2450. )
  2451. self._write_bytes(self._null_terminate_bytes(ts))
  2452. def _write_variable_types(self) -> None:
  2453. for typ in self.typlist:
  2454. self._write_bytes(struct.pack("B", typ))
  2455. def _write_varnames(self) -> None:
  2456. # varlist names are checked by _check_column_names
  2457. # varlist, requires null terminated
  2458. for name in self.varlist:
  2459. name = self._null_terminate_str(name)
  2460. name = _pad_bytes(name[:32], 33)
  2461. self._write(name)
  2462. def _write_sortlist(self) -> None:
  2463. # srtlist, 2*(nvar+1), int array, encoded by byteorder
  2464. srtlist = _pad_bytes("", 2 * (self.nvar + 1))
  2465. self._write(srtlist)
  2466. def _write_formats(self) -> None:
  2467. # fmtlist, 49*nvar, char array
  2468. for fmt in self.fmtlist:
  2469. self._write(_pad_bytes(fmt, 49))
  2470. def _write_value_label_names(self) -> None:
  2471. # lbllist, 33*nvar, char array
  2472. for i in range(self.nvar):
  2473. # Use variable name when categorical
  2474. if self._has_value_labels[i]:
  2475. name = self.varlist[i]
  2476. name = self._null_terminate_str(name)
  2477. name = _pad_bytes(name[:32], 33)
  2478. self._write(name)
  2479. else: # Default is empty label
  2480. self._write(_pad_bytes("", 33))
  2481. def _write_variable_labels(self) -> None:
  2482. # Missing labels are 80 blank characters plus null termination
  2483. blank = _pad_bytes("", 81)
  2484. if self._variable_labels is None:
  2485. for i in range(self.nvar):
  2486. self._write(blank)
  2487. return
  2488. for col in self.data:
  2489. if col in self._variable_labels:
  2490. label = self._variable_labels[col]
  2491. if len(label) > 80:
  2492. raise ValueError("Variable labels must be 80 characters or fewer")
  2493. is_latin1 = all(ord(c) < 256 for c in label)
  2494. if not is_latin1:
  2495. raise ValueError(
  2496. "Variable labels must contain only characters that "
  2497. "can be encoded in Latin-1"
  2498. )
  2499. self._write(_pad_bytes(label, 81))
  2500. else:
  2501. self._write(blank)
  2502. def _convert_strls(self, data: DataFrame) -> DataFrame:
  2503. """No-op, future compatibility"""
  2504. return data
  2505. def _prepare_data(self) -> np.recarray:
  2506. data = self.data
  2507. typlist = self.typlist
  2508. convert_dates = self._convert_dates
  2509. # 1. Convert dates
  2510. if self._convert_dates is not None:
  2511. for i, col in enumerate(data):
  2512. if i in convert_dates:
  2513. data[col] = _datetime_to_stata_elapsed_vec(
  2514. data[col], self.fmtlist[i]
  2515. )
  2516. # 2. Convert strls
  2517. data = self._convert_strls(data)
  2518. # 3. Convert bad string data to '' and pad to correct length
  2519. dtypes = {}
  2520. native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
  2521. for i, col in enumerate(data):
  2522. typ = typlist[i]
  2523. if typ <= self._max_string_length:
  2524. data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,))
  2525. stype = f"S{typ}"
  2526. dtypes[col] = stype
  2527. data[col] = data[col].astype(stype)
  2528. else:
  2529. dtype = data[col].dtype
  2530. if not native_byteorder:
  2531. dtype = dtype.newbyteorder(self._byteorder)
  2532. dtypes[col] = dtype
  2533. return data.to_records(index=False, column_dtypes=dtypes)
  2534. def _write_data(self, records: np.recarray) -> None:
  2535. self._write_bytes(records.tobytes())
  2536. @staticmethod
  2537. def _null_terminate_str(s: str) -> str:
  2538. s += "\x00"
  2539. return s
  2540. def _null_terminate_bytes(self, s: str) -> bytes:
  2541. return self._null_terminate_str(s).encode(self._encoding)
  2542. def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
  2543. """
  2544. Converts dtype types to stata types. Returns the byte of the given ordinal.
  2545. See TYPE_MAP and comments for an explanation. This is also explained in
  2546. the dta spec.
  2547. 1 - 2045 are strings of this length
  2548. Pandas Stata
  2549. 32768 - for object strL
  2550. 65526 - for int8 byte
  2551. 65527 - for int16 int
  2552. 65528 - for int32 long
  2553. 65529 - for float32 float
  2554. 65530 - for double double
  2555. If there are dates to convert, then dtype will already have the correct
  2556. type inserted.
  2557. """
  2558. # TODO: expand to handle datetime to integer conversion
  2559. if force_strl:
  2560. return 32768
  2561. if dtype.type is np.object_: # try to coerce it to the biggest string
  2562. # not memory efficient, what else could we
  2563. # do?
  2564. itemsize = max_len_string_array(ensure_object(column._values))
  2565. itemsize = max(itemsize, 1)
  2566. if itemsize <= 2045:
  2567. return itemsize
  2568. return 32768
  2569. elif dtype.type is np.float64:
  2570. return 65526
  2571. elif dtype.type is np.float32:
  2572. return 65527
  2573. elif dtype.type is np.int32:
  2574. return 65528
  2575. elif dtype.type is np.int16:
  2576. return 65529
  2577. elif dtype.type is np.int8:
  2578. return 65530
  2579. else: # pragma : no cover
  2580. raise NotImplementedError(f"Data type {dtype} not supported.")
  2581. def _pad_bytes_new(name: str | bytes, length: int) -> bytes:
  2582. """
  2583. Takes a bytes instance and pads it with null bytes until it's length chars.
  2584. """
  2585. if isinstance(name, str):
  2586. name = bytes(name, "utf-8")
  2587. return name + b"\x00" * (length - len(name))
  2588. class StataStrLWriter:
  2589. """
  2590. Converter for Stata StrLs
  2591. Stata StrLs map 8 byte values to strings which are stored using a
  2592. dictionary-like format where strings are keyed to two values.
  2593. Parameters
  2594. ----------
  2595. df : DataFrame
  2596. DataFrame to convert
  2597. columns : Sequence[str]
  2598. List of columns names to convert to StrL
  2599. version : int, optional
  2600. dta version. Currently supports 117, 118 and 119
  2601. byteorder : str, optional
  2602. Can be ">", "<", "little", or "big". default is `sys.byteorder`
  2603. Notes
  2604. -----
  2605. Supports creation of the StrL block of a dta file for dta versions
  2606. 117, 118 and 119. These differ in how the GSO is stored. 118 and
  2607. 119 store the GSO lookup value as a uint32 and a uint64, while 117
  2608. uses two uint32s. 118 and 119 also encode all strings as unicode
  2609. which is required by the format. 117 uses 'latin-1' a fixed width
  2610. encoding that extends the 7-bit ascii table with an additional 128
  2611. characters.
  2612. """
  2613. def __init__(
  2614. self,
  2615. df: DataFrame,
  2616. columns: Sequence[str],
  2617. version: int = 117,
  2618. byteorder: str | None = None,
  2619. ) -> None:
  2620. if version not in (117, 118, 119):
  2621. raise ValueError("Only dta versions 117, 118 and 119 supported")
  2622. self._dta_ver = version
  2623. self.df = df
  2624. self.columns = columns
  2625. self._gso_table = {"": (0, 0)}
  2626. if byteorder is None:
  2627. byteorder = sys.byteorder
  2628. self._byteorder = _set_endianness(byteorder)
  2629. gso_v_type = "I" # uint32
  2630. gso_o_type = "Q" # uint64
  2631. self._encoding = "utf-8"
  2632. if version == 117:
  2633. o_size = 4
  2634. gso_o_type = "I" # 117 used uint32
  2635. self._encoding = "latin-1"
  2636. elif version == 118:
  2637. o_size = 6
  2638. else: # version == 119
  2639. o_size = 5
  2640. self._o_offet = 2 ** (8 * (8 - o_size))
  2641. self._gso_o_type = gso_o_type
  2642. self._gso_v_type = gso_v_type
  2643. def _convert_key(self, key: tuple[int, int]) -> int:
  2644. v, o = key
  2645. return v + self._o_offet * o
  2646. def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]:
  2647. """
  2648. Generates the GSO lookup table for the DataFrame
  2649. Returns
  2650. -------
  2651. gso_table : dict
  2652. Ordered dictionary using the string found as keys
  2653. and their lookup position (v,o) as values
  2654. gso_df : DataFrame
  2655. DataFrame where strl columns have been converted to
  2656. (v,o) values
  2657. Notes
  2658. -----
  2659. Modifies the DataFrame in-place.
  2660. The DataFrame returned encodes the (v,o) values as uint64s. The
  2661. encoding depends on the dta version, and can be expressed as
  2662. enc = v + o * 2 ** (o_size * 8)
  2663. so that v is stored in the lower bits and o is in the upper
  2664. bits. o_size is
  2665. * 117: 4
  2666. * 118: 6
  2667. * 119: 5
  2668. """
  2669. gso_table = self._gso_table
  2670. gso_df = self.df
  2671. columns = list(gso_df.columns)
  2672. selected = gso_df[self.columns]
  2673. col_index = [(col, columns.index(col)) for col in self.columns]
  2674. keys = np.empty(selected.shape, dtype=np.uint64)
  2675. for o, (idx, row) in enumerate(selected.iterrows()):
  2676. for j, (col, v) in enumerate(col_index):
  2677. val = row[col]
  2678. # Allow columns with mixed str and None (GH 23633)
  2679. val = "" if val is None else val
  2680. key = gso_table.get(val, None)
  2681. if key is None:
  2682. # Stata prefers human numbers
  2683. key = (v + 1, o + 1)
  2684. gso_table[val] = key
  2685. keys[o, j] = self._convert_key(key)
  2686. for i, col in enumerate(self.columns):
  2687. gso_df[col] = keys[:, i]
  2688. return gso_table, gso_df
  2689. def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes:
  2690. """
  2691. Generates the binary blob of GSOs that is written to the dta file.
  2692. Parameters
  2693. ----------
  2694. gso_table : dict
  2695. Ordered dictionary (str, vo)
  2696. Returns
  2697. -------
  2698. gso : bytes
  2699. Binary content of dta file to be placed between strl tags
  2700. Notes
  2701. -----
  2702. Output format depends on dta version. 117 uses two uint32s to
  2703. express v and o while 118+ uses a uint32 for v and a uint64 for o.
  2704. """
  2705. # Format information
  2706. # Length includes null term
  2707. # 117
  2708. # GSOvvvvooootllllxxxxxxxxxxxxxxx...x
  2709. # 3 u4 u4 u1 u4 string + null term
  2710. #
  2711. # 118, 119
  2712. # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x
  2713. # 3 u4 u8 u1 u4 string + null term
  2714. bio = BytesIO()
  2715. gso = bytes("GSO", "ascii")
  2716. gso_type = struct.pack(self._byteorder + "B", 130)
  2717. null = struct.pack(self._byteorder + "B", 0)
  2718. v_type = self._byteorder + self._gso_v_type
  2719. o_type = self._byteorder + self._gso_o_type
  2720. len_type = self._byteorder + "I"
  2721. for strl, vo in gso_table.items():
  2722. if vo == (0, 0):
  2723. continue
  2724. v, o = vo
  2725. # GSO
  2726. bio.write(gso)
  2727. # vvvv
  2728. bio.write(struct.pack(v_type, v))
  2729. # oooo / oooooooo
  2730. bio.write(struct.pack(o_type, o))
  2731. # t
  2732. bio.write(gso_type)
  2733. # llll
  2734. utf8_string = bytes(strl, "utf-8")
  2735. bio.write(struct.pack(len_type, len(utf8_string) + 1))
  2736. # xxx...xxx
  2737. bio.write(utf8_string)
  2738. bio.write(null)
  2739. return bio.getvalue()
  2740. class StataWriter117(StataWriter):
  2741. """
  2742. A class for writing Stata binary dta files in Stata 13 format (117)
  2743. Parameters
  2744. ----------
  2745. fname : path (string), buffer or path object
  2746. string, path object (pathlib.Path or py._path.local.LocalPath) or
  2747. object implementing a binary write() functions. If using a buffer
  2748. then the buffer will not be automatically closed after the file
  2749. is written.
  2750. data : DataFrame
  2751. Input to save
  2752. convert_dates : dict
  2753. Dictionary mapping columns containing datetime types to stata internal
  2754. format to use when writing the dates. Options are 'tc', 'td', 'tm',
  2755. 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
  2756. Datetime columns that do not have a conversion type specified will be
  2757. converted to 'tc'. Raises NotImplementedError if a datetime column has
  2758. timezone information
  2759. write_index : bool
  2760. Write the index to Stata dataset.
  2761. byteorder : str
  2762. Can be ">", "<", "little", or "big". default is `sys.byteorder`
  2763. time_stamp : datetime
  2764. A datetime to use as file creation date. Default is the current time
  2765. data_label : str
  2766. A label for the data set. Must be 80 characters or smaller.
  2767. variable_labels : dict
  2768. Dictionary containing columns as keys and variable labels as values.
  2769. Each label must be 80 characters or smaller.
  2770. convert_strl : list
  2771. List of columns names to convert to Stata StrL format. Columns with
  2772. more than 2045 characters are automatically written as StrL.
  2773. Smaller columns can be converted by including the column name. Using
  2774. StrLs can reduce output file size when strings are longer than 8
  2775. characters, and either frequently repeated or sparse.
  2776. {compression_options}
  2777. .. versionadded:: 1.1.0
  2778. .. versionchanged:: 1.4.0 Zstandard support.
  2779. value_labels : dict of dicts
  2780. Dictionary containing columns as keys and dictionaries of column value
  2781. to labels as values. The combined length of all labels for a single
  2782. variable must be 32,000 characters or smaller.
  2783. .. versionadded:: 1.4.0
  2784. Returns
  2785. -------
  2786. writer : StataWriter117 instance
  2787. The StataWriter117 instance has a write_file method, which will
  2788. write the file to the given `fname`.
  2789. Raises
  2790. ------
  2791. NotImplementedError
  2792. * If datetimes contain timezone information
  2793. ValueError
  2794. * Columns listed in convert_dates are neither datetime64[ns]
  2795. or datetime.datetime
  2796. * Column dtype is not representable in Stata
  2797. * Column listed in convert_dates is not in DataFrame
  2798. * Categorical label contains more than 32,000 characters
  2799. Examples
  2800. --------
  2801. >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
  2802. >>> writer = pd.io.stata.StataWriter117('./data_file.dta', data)
  2803. >>> writer.write_file()
  2804. Directly write a zip file
  2805. >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
  2806. >>> writer = pd.io.stata.StataWriter117(
  2807. ... './data_file.zip', data, compression=compression
  2808. ... )
  2809. >>> writer.write_file()
  2810. Or with long strings stored in strl format
  2811. >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
  2812. ... columns=['strls'])
  2813. >>> writer = pd.io.stata.StataWriter117(
  2814. ... './data_file_with_long_strings.dta', data, convert_strl=['strls'])
  2815. >>> writer.write_file()
  2816. """
  2817. _max_string_length = 2045
  2818. _dta_version = 117
  2819. def __init__(
  2820. self,
  2821. fname: FilePath | WriteBuffer[bytes],
  2822. data: DataFrame,
  2823. convert_dates: dict[Hashable, str] | None = None,
  2824. write_index: bool = True,
  2825. byteorder: str | None = None,
  2826. time_stamp: datetime.datetime | None = None,
  2827. data_label: str | None = None,
  2828. variable_labels: dict[Hashable, str] | None = None,
  2829. convert_strl: Sequence[Hashable] | None = None,
  2830. compression: CompressionOptions = "infer",
  2831. storage_options: StorageOptions = None,
  2832. *,
  2833. value_labels: dict[Hashable, dict[float, str]] | None = None,
  2834. ) -> None:
  2835. # Copy to new list since convert_strl might be modified later
  2836. self._convert_strl: list[Hashable] = []
  2837. if convert_strl is not None:
  2838. self._convert_strl.extend(convert_strl)
  2839. super().__init__(
  2840. fname,
  2841. data,
  2842. convert_dates,
  2843. write_index,
  2844. byteorder=byteorder,
  2845. time_stamp=time_stamp,
  2846. data_label=data_label,
  2847. variable_labels=variable_labels,
  2848. value_labels=value_labels,
  2849. compression=compression,
  2850. storage_options=storage_options,
  2851. )
  2852. self._map: dict[str, int] = {}
  2853. self._strl_blob = b""
  2854. @staticmethod
  2855. def _tag(val: str | bytes, tag: str) -> bytes:
  2856. """Surround val with <tag></tag>"""
  2857. if isinstance(val, str):
  2858. val = bytes(val, "utf-8")
  2859. return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8")
  2860. def _update_map(self, tag: str) -> None:
  2861. """Update map location for tag with file position"""
  2862. assert self.handles.handle is not None
  2863. self._map[tag] = self.handles.handle.tell()
  2864. def _write_header(
  2865. self,
  2866. data_label: str | None = None,
  2867. time_stamp: datetime.datetime | None = None,
  2868. ) -> None:
  2869. """Write the file header"""
  2870. byteorder = self._byteorder
  2871. self._write_bytes(bytes("<stata_dta>", "utf-8"))
  2872. bio = BytesIO()
  2873. # ds_format - 117
  2874. bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release"))
  2875. # byteorder
  2876. bio.write(self._tag(byteorder == ">" and "MSF" or "LSF", "byteorder"))
  2877. # number of vars, 2 bytes in 117 and 118, 4 byte in 119
  2878. nvar_type = "H" if self._dta_version <= 118 else "I"
  2879. bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K"))
  2880. # 117 uses 4 bytes, 118 uses 8
  2881. nobs_size = "I" if self._dta_version == 117 else "Q"
  2882. bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
  2883. # data label 81 bytes, char, null terminated
  2884. label = data_label[:80] if data_label is not None else ""
  2885. encoded_label = label.encode(self._encoding)
  2886. label_size = "B" if self._dta_version == 117 else "H"
  2887. label_len = struct.pack(byteorder + label_size, len(encoded_label))
  2888. encoded_label = label_len + encoded_label
  2889. bio.write(self._tag(encoded_label, "label"))
  2890. # time stamp, 18 bytes, char, null terminated
  2891. # format dd Mon yyyy hh:mm
  2892. if time_stamp is None:
  2893. time_stamp = datetime.datetime.now()
  2894. elif not isinstance(time_stamp, datetime.datetime):
  2895. raise ValueError("time_stamp should be datetime type")
  2896. # Avoid locale-specific month conversion
  2897. months = [
  2898. "Jan",
  2899. "Feb",
  2900. "Mar",
  2901. "Apr",
  2902. "May",
  2903. "Jun",
  2904. "Jul",
  2905. "Aug",
  2906. "Sep",
  2907. "Oct",
  2908. "Nov",
  2909. "Dec",
  2910. ]
  2911. month_lookup = {i + 1: month for i, month in enumerate(months)}
  2912. ts = (
  2913. time_stamp.strftime("%d ")
  2914. + month_lookup[time_stamp.month]
  2915. + time_stamp.strftime(" %Y %H:%M")
  2916. )
  2917. # '\x11' added due to inspection of Stata file
  2918. stata_ts = b"\x11" + bytes(ts, "utf-8")
  2919. bio.write(self._tag(stata_ts, "timestamp"))
  2920. self._write_bytes(self._tag(bio.getvalue(), "header"))
  2921. def _write_map(self) -> None:
  2922. """
  2923. Called twice during file write. The first populates the values in
  2924. the map with 0s. The second call writes the final map locations when
  2925. all blocks have been written.
  2926. """
  2927. if not self._map:
  2928. self._map = {
  2929. "stata_data": 0,
  2930. "map": self.handles.handle.tell(),
  2931. "variable_types": 0,
  2932. "varnames": 0,
  2933. "sortlist": 0,
  2934. "formats": 0,
  2935. "value_label_names": 0,
  2936. "variable_labels": 0,
  2937. "characteristics": 0,
  2938. "data": 0,
  2939. "strls": 0,
  2940. "value_labels": 0,
  2941. "stata_data_close": 0,
  2942. "end-of-file": 0,
  2943. }
  2944. # Move to start of map
  2945. self.handles.handle.seek(self._map["map"])
  2946. bio = BytesIO()
  2947. for val in self._map.values():
  2948. bio.write(struct.pack(self._byteorder + "Q", val))
  2949. self._write_bytes(self._tag(bio.getvalue(), "map"))
  2950. def _write_variable_types(self) -> None:
  2951. self._update_map("variable_types")
  2952. bio = BytesIO()
  2953. for typ in self.typlist:
  2954. bio.write(struct.pack(self._byteorder + "H", typ))
  2955. self._write_bytes(self._tag(bio.getvalue(), "variable_types"))
  2956. def _write_varnames(self) -> None:
  2957. self._update_map("varnames")
  2958. bio = BytesIO()
  2959. # 118 scales by 4 to accommodate utf-8 data worst case encoding
  2960. vn_len = 32 if self._dta_version == 117 else 128
  2961. for name in self.varlist:
  2962. name = self._null_terminate_str(name)
  2963. name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
  2964. bio.write(name)
  2965. self._write_bytes(self._tag(bio.getvalue(), "varnames"))
  2966. def _write_sortlist(self) -> None:
  2967. self._update_map("sortlist")
  2968. sort_size = 2 if self._dta_version < 119 else 4
  2969. self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist"))
  2970. def _write_formats(self) -> None:
  2971. self._update_map("formats")
  2972. bio = BytesIO()
  2973. fmt_len = 49 if self._dta_version == 117 else 57
  2974. for fmt in self.fmtlist:
  2975. bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len))
  2976. self._write_bytes(self._tag(bio.getvalue(), "formats"))
  2977. def _write_value_label_names(self) -> None:
  2978. self._update_map("value_label_names")
  2979. bio = BytesIO()
  2980. # 118 scales by 4 to accommodate utf-8 data worst case encoding
  2981. vl_len = 32 if self._dta_version == 117 else 128
  2982. for i in range(self.nvar):
  2983. # Use variable name when categorical
  2984. name = "" # default name
  2985. if self._has_value_labels[i]:
  2986. name = self.varlist[i]
  2987. name = self._null_terminate_str(name)
  2988. encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
  2989. bio.write(encoded_name)
  2990. self._write_bytes(self._tag(bio.getvalue(), "value_label_names"))
  2991. def _write_variable_labels(self) -> None:
  2992. # Missing labels are 80 blank characters plus null termination
  2993. self._update_map("variable_labels")
  2994. bio = BytesIO()
  2995. # 118 scales by 4 to accommodate utf-8 data worst case encoding
  2996. vl_len = 80 if self._dta_version == 117 else 320
  2997. blank = _pad_bytes_new("", vl_len + 1)
  2998. if self._variable_labels is None:
  2999. for _ in range(self.nvar):
  3000. bio.write(blank)
  3001. self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
  3002. return
  3003. for col in self.data:
  3004. if col in self._variable_labels:
  3005. label = self._variable_labels[col]
  3006. if len(label) > 80:
  3007. raise ValueError("Variable labels must be 80 characters or fewer")
  3008. try:
  3009. encoded = label.encode(self._encoding)
  3010. except UnicodeEncodeError as err:
  3011. raise ValueError(
  3012. "Variable labels must contain only characters that "
  3013. f"can be encoded in {self._encoding}"
  3014. ) from err
  3015. bio.write(_pad_bytes_new(encoded, vl_len + 1))
  3016. else:
  3017. bio.write(blank)
  3018. self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
  3019. def _write_characteristics(self) -> None:
  3020. self._update_map("characteristics")
  3021. self._write_bytes(self._tag(b"", "characteristics"))
  3022. def _write_data(self, records) -> None:
  3023. self._update_map("data")
  3024. self._write_bytes(b"<data>")
  3025. self._write_bytes(records.tobytes())
  3026. self._write_bytes(b"</data>")
  3027. def _write_strls(self) -> None:
  3028. self._update_map("strls")
  3029. self._write_bytes(self._tag(self._strl_blob, "strls"))
  3030. def _write_expansion_fields(self) -> None:
  3031. """No-op in dta 117+"""
  3032. def _write_value_labels(self) -> None:
  3033. self._update_map("value_labels")
  3034. bio = BytesIO()
  3035. for vl in self._value_labels:
  3036. lab = vl.generate_value_label(self._byteorder)
  3037. lab = self._tag(lab, "lbl")
  3038. bio.write(lab)
  3039. self._write_bytes(self._tag(bio.getvalue(), "value_labels"))
  3040. def _write_file_close_tag(self) -> None:
  3041. self._update_map("stata_data_close")
  3042. self._write_bytes(bytes("</stata_dta>", "utf-8"))
  3043. self._update_map("end-of-file")
  3044. def _update_strl_names(self) -> None:
  3045. """
  3046. Update column names for conversion to strl if they might have been
  3047. changed to comply with Stata naming rules
  3048. """
  3049. # Update convert_strl if names changed
  3050. for orig, new in self._converted_names.items():
  3051. if orig in self._convert_strl:
  3052. idx = self._convert_strl.index(orig)
  3053. self._convert_strl[idx] = new
  3054. def _convert_strls(self, data: DataFrame) -> DataFrame:
  3055. """
  3056. Convert columns to StrLs if either very large or in the
  3057. convert_strl variable
  3058. """
  3059. convert_cols = [
  3060. col
  3061. for i, col in enumerate(data)
  3062. if self.typlist[i] == 32768 or col in self._convert_strl
  3063. ]
  3064. if convert_cols:
  3065. ssw = StataStrLWriter(data, convert_cols, version=self._dta_version)
  3066. tab, new_data = ssw.generate_table()
  3067. data = new_data
  3068. self._strl_blob = ssw.generate_blob(tab)
  3069. return data
  3070. def _set_formats_and_types(self, dtypes: Series) -> None:
  3071. self.typlist = []
  3072. self.fmtlist = []
  3073. for col, dtype in dtypes.items():
  3074. force_strl = col in self._convert_strl
  3075. fmt = _dtype_to_default_stata_fmt(
  3076. dtype,
  3077. self.data[col],
  3078. dta_version=self._dta_version,
  3079. force_strl=force_strl,
  3080. )
  3081. self.fmtlist.append(fmt)
  3082. self.typlist.append(
  3083. _dtype_to_stata_type_117(dtype, self.data[col], force_strl)
  3084. )
  3085. class StataWriterUTF8(StataWriter117):
  3086. """
  3087. Stata binary dta file writing in Stata 15 (118) and 16 (119) formats
  3088. DTA 118 and 119 format files support unicode string data (both fixed
  3089. and strL) format. Unicode is also supported in value labels, variable
  3090. labels and the dataset label. Format 119 is automatically used if the
  3091. file contains more than 32,767 variables.
  3092. Parameters
  3093. ----------
  3094. fname : path (string), buffer or path object
  3095. string, path object (pathlib.Path or py._path.local.LocalPath) or
  3096. object implementing a binary write() functions. If using a buffer
  3097. then the buffer will not be automatically closed after the file
  3098. is written.
  3099. data : DataFrame
  3100. Input to save
  3101. convert_dates : dict, default None
  3102. Dictionary mapping columns containing datetime types to stata internal
  3103. format to use when writing the dates. Options are 'tc', 'td', 'tm',
  3104. 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
  3105. Datetime columns that do not have a conversion type specified will be
  3106. converted to 'tc'. Raises NotImplementedError if a datetime column has
  3107. timezone information
  3108. write_index : bool, default True
  3109. Write the index to Stata dataset.
  3110. byteorder : str, default None
  3111. Can be ">", "<", "little", or "big". default is `sys.byteorder`
  3112. time_stamp : datetime, default None
  3113. A datetime to use as file creation date. Default is the current time
  3114. data_label : str, default None
  3115. A label for the data set. Must be 80 characters or smaller.
  3116. variable_labels : dict, default None
  3117. Dictionary containing columns as keys and variable labels as values.
  3118. Each label must be 80 characters or smaller.
  3119. convert_strl : list, default None
  3120. List of columns names to convert to Stata StrL format. Columns with
  3121. more than 2045 characters are automatically written as StrL.
  3122. Smaller columns can be converted by including the column name. Using
  3123. StrLs can reduce output file size when strings are longer than 8
  3124. characters, and either frequently repeated or sparse.
  3125. version : int, default None
  3126. The dta version to use. By default, uses the size of data to determine
  3127. the version. 118 is used if data.shape[1] <= 32767, and 119 is used
  3128. for storing larger DataFrames.
  3129. {compression_options}
  3130. .. versionadded:: 1.1.0
  3131. .. versionchanged:: 1.4.0 Zstandard support.
  3132. value_labels : dict of dicts
  3133. Dictionary containing columns as keys and dictionaries of column value
  3134. to labels as values. The combined length of all labels for a single
  3135. variable must be 32,000 characters or smaller.
  3136. .. versionadded:: 1.4.0
  3137. Returns
  3138. -------
  3139. StataWriterUTF8
  3140. The instance has a write_file method, which will write the file to the
  3141. given `fname`.
  3142. Raises
  3143. ------
  3144. NotImplementedError
  3145. * If datetimes contain timezone information
  3146. ValueError
  3147. * Columns listed in convert_dates are neither datetime64[ns]
  3148. or datetime.datetime
  3149. * Column dtype is not representable in Stata
  3150. * Column listed in convert_dates is not in DataFrame
  3151. * Categorical label contains more than 32,000 characters
  3152. Examples
  3153. --------
  3154. Using Unicode data and column names
  3155. >>> from pandas.io.stata import StataWriterUTF8
  3156. >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ'])
  3157. >>> writer = StataWriterUTF8('./data_file.dta', data)
  3158. >>> writer.write_file()
  3159. Directly write a zip file
  3160. >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
  3161. >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
  3162. >>> writer.write_file()
  3163. Or with long strings stored in strl format
  3164. >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
  3165. ... columns=['strls'])
  3166. >>> writer = StataWriterUTF8('./data_file_with_long_strings.dta', data,
  3167. ... convert_strl=['strls'])
  3168. >>> writer.write_file()
  3169. """
  3170. _encoding: Literal["utf-8"] = "utf-8"
  3171. def __init__(
  3172. self,
  3173. fname: FilePath | WriteBuffer[bytes],
  3174. data: DataFrame,
  3175. convert_dates: dict[Hashable, str] | None = None,
  3176. write_index: bool = True,
  3177. byteorder: str | None = None,
  3178. time_stamp: datetime.datetime | None = None,
  3179. data_label: str | None = None,
  3180. variable_labels: dict[Hashable, str] | None = None,
  3181. convert_strl: Sequence[Hashable] | None = None,
  3182. version: int | None = None,
  3183. compression: CompressionOptions = "infer",
  3184. storage_options: StorageOptions = None,
  3185. *,
  3186. value_labels: dict[Hashable, dict[float, str]] | None = None,
  3187. ) -> None:
  3188. if version is None:
  3189. version = 118 if data.shape[1] <= 32767 else 119
  3190. elif version not in (118, 119):
  3191. raise ValueError("version must be either 118 or 119.")
  3192. elif version == 118 and data.shape[1] > 32767:
  3193. raise ValueError(
  3194. "You must use version 119 for data sets containing more than"
  3195. "32,767 variables"
  3196. )
  3197. super().__init__(
  3198. fname,
  3199. data,
  3200. convert_dates=convert_dates,
  3201. write_index=write_index,
  3202. byteorder=byteorder,
  3203. time_stamp=time_stamp,
  3204. data_label=data_label,
  3205. variable_labels=variable_labels,
  3206. value_labels=value_labels,
  3207. convert_strl=convert_strl,
  3208. compression=compression,
  3209. storage_options=storage_options,
  3210. )
  3211. # Override version set in StataWriter117 init
  3212. self._dta_version = version
  3213. def _validate_variable_name(self, name: str) -> str:
  3214. """
  3215. Validate variable names for Stata export.
  3216. Parameters
  3217. ----------
  3218. name : str
  3219. Variable name
  3220. Returns
  3221. -------
  3222. str
  3223. The validated name with invalid characters replaced with
  3224. underscores.
  3225. Notes
  3226. -----
  3227. Stata 118+ support most unicode characters. The only limitation is in
  3228. the ascii range where the characters supported are a-z, A-Z, 0-9 and _.
  3229. """
  3230. # High code points appear to be acceptable
  3231. for c in name:
  3232. if (
  3233. (
  3234. ord(c) < 128
  3235. and (c < "A" or c > "Z")
  3236. and (c < "a" or c > "z")
  3237. and (c < "0" or c > "9")
  3238. and c != "_"
  3239. )
  3240. or 128 <= ord(c) < 192
  3241. or c in {"×", "÷"}
  3242. ):
  3243. name = name.replace(c, "_")
  3244. return name