Skip to content

Commit 2a191d3

Browse files
committed
CLN: StataReader: refactor repeated struct.unpack/read calls to helpers
1 parent 8cb6382 commit 2a191d3

File tree

1 file changed

+69
-71
lines changed

1 file changed

+69
-71
lines changed

pandas/io/stata.py

Lines changed: 69 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -1202,9 +1202,42 @@ def _set_encoding(self) -> None:
12021202
else:
12031203
self._encoding = "utf-8"
12041204

1205+
def _read_int8(self) -> int:
    """Read one byte from ``self.path_or_buf`` as a signed 8-bit integer.

    ``struct.unpack`` raises ``struct.error`` if the read does not yield
    exactly one byte.
    """
    raw = self.path_or_buf.read(1)
    (value,) = struct.unpack("b", raw)
    return value
1207+
1208+
def _read_uint8(self) -> int:
    """Read one byte from ``self.path_or_buf`` as an unsigned 8-bit integer.

    ``struct.unpack`` raises ``struct.error`` if the read does not yield
    exactly one byte.
    """
    raw = self.path_or_buf.read(1)
    (value,) = struct.unpack("B", raw)
    return value
1210+
1211+
def _read_uint16(self) -> int:
    """Read two bytes from ``self.path_or_buf`` as an unsigned 16-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}H"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(2))
    return value
1213+
1214+
def _read_uint32(self) -> int:
    """Read four bytes from ``self.path_or_buf`` as an unsigned 32-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}I"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(4))
    return value
1216+
1217+
def _read_uint64(self) -> int:
    """Read eight bytes from ``self.path_or_buf`` as an unsigned 64-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}Q"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(8))
    return value
1219+
1220+
def _read_int16(self) -> int:
    """Read two bytes from ``self.path_or_buf`` as a signed 16-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}h"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(2))
    return value
1222+
1223+
def _read_int32(self) -> int:
    """Read four bytes from ``self.path_or_buf`` as a signed 32-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}i"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(4))
    return value
1225+
1226+
def _read_int64(self) -> int:
    """Read eight bytes from ``self.path_or_buf`` as a signed 64-bit integer.

    The struct byte-order prefix is taken from ``self.byteorder``.
    """
    fmt = f"{self.byteorder}q"
    (value,) = struct.unpack(fmt, self.path_or_buf.read(8))
    return value
1228+
1229+
def _read_char8(self) -> bytes:
    """Read one byte from ``self.path_or_buf`` and return it as a length-1 ``bytes``.

    Uses the struct ``"c"`` format, so ``struct.error`` is raised if the
    read does not yield exactly one byte.
    """
    raw = self.path_or_buf.read(1)
    (char,) = struct.unpack("c", raw)
    return char
1231+
1232+
def _read_int16_count(self, count: int) -> tuple[int, ...]:
    """Read ``count`` consecutive signed 16-bit integers from ``self.path_or_buf``.

    Parameters
    ----------
    count : int
        Number of 16-bit values to read; ``2 * count`` bytes are consumed.

    Returns
    -------
    tuple of int
        The decoded values in stream order (empty when ``count`` is 0).

    Raises
    ------
    ValueError
        If ``count`` is negative.
    struct.error
        If the read yields fewer than ``2 * count`` bytes.
    """
    if count < 0:
        # A negative size would make ``read`` consume the rest of the stream
        # before the unpack fails, so reject it before touching the buffer.
        raise ValueError(f"count must be non-negative, got {count}")
    # A struct repeat count ("<3h") is equivalent to spelling the code out
    # ("<hhh") but keeps the format string constant-size.
    fmt = f"{self.byteorder}{count}h"
    return struct.unpack(fmt, self.path_or_buf.read(2 * count))
1237+
12051238
def _read_header(self) -> None:
1206-
first_char = self.path_or_buf.read(1)
1207-
if struct.unpack("c", first_char)[0] == b"<":
1239+
first_char = self._read_char8()
1240+
if first_char == b"<":
12081241
self._read_new_header()
12091242
else:
12101243
self._read_old_header(first_char)
@@ -1224,11 +1257,9 @@ def _read_new_header(self) -> None:
12241257
self.path_or_buf.read(21) # </release><byteorder>
12251258
self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<"
12261259
self.path_or_buf.read(15) # </byteorder><K>
1227-
nvar_type = "H" if self.format_version <= 118 else "I"
1228-
nvar_size = 2 if self.format_version <= 118 else 4
1229-
self.nvar = struct.unpack(
1230-
self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
1231-
)[0]
1260+
self.nvar = (
1261+
self._read_uint16() if self.format_version <= 118 else self._read_uint32()
1262+
)
12321263
self.path_or_buf.read(7) # </K><N>
12331264

12341265
self.nobs = self._get_nobs()
@@ -1240,46 +1271,27 @@ def _read_new_header(self) -> None:
12401271
self.path_or_buf.read(8) # 0x0000000000000000
12411272
self.path_or_buf.read(8) # position of <map>
12421273

1243-
self._seek_vartypes = (
1244-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
1245-
)
1246-
self._seek_varnames = (
1247-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1248-
)
1249-
self._seek_sortlist = (
1250-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1251-
)
1252-
self._seek_formats = (
1253-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
1254-
)
1255-
self._seek_value_label_names = (
1256-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
1257-
)
1274+
self._seek_vartypes = self._read_int64() + 16
1275+
self._seek_varnames = self._read_int64() + 10
1276+
self._seek_sortlist = self._read_int64() + 10
1277+
self._seek_formats = self._read_int64() + 9
1278+
self._seek_value_label_names = self._read_int64() + 19
12581279

12591280
# Requires version-specific treatment
12601281
self._seek_variable_labels = self._get_seek_variable_labels()
12611282

12621283
self.path_or_buf.read(8) # <characteristics>
1263-
self.data_location = (
1264-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
1265-
)
1266-
self.seek_strls = (
1267-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
1268-
)
1269-
self.seek_value_labels = (
1270-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
1271-
)
1284+
self.data_location = self._read_int64() + 6
1285+
self.seek_strls = self._read_int64() + 7
1286+
self.seek_value_labels = self._read_int64() + 14
12721287

12731288
self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
12741289

12751290
self.path_or_buf.seek(self._seek_varnames)
12761291
self.varlist = self._get_varlist()
12771292

12781293
self.path_or_buf.seek(self._seek_sortlist)
1279-
self.srtlist = struct.unpack(
1280-
self.byteorder + ("h" * (self.nvar + 1)),
1281-
self.path_or_buf.read(2 * (self.nvar + 1)),
1282-
)[:-1]
1294+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
12831295

12841296
self.path_or_buf.seek(self._seek_formats)
12851297
self.fmtlist = self._get_fmtlist()
@@ -1296,10 +1308,7 @@ def _get_dtypes(
12961308
) -> tuple[list[int | str], list[str | np.dtype]]:
12971309

12981310
self.path_or_buf.seek(seek_vartypes)
1299-
raw_typlist = [
1300-
struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1301-
for _ in range(self.nvar)
1302-
]
1311+
raw_typlist = [self._read_uint16() for _ in range(self.nvar)]
13031312

13041313
def f(typ: int) -> int | str:
13051314
if typ <= 2045:
@@ -1368,16 +1377,16 @@ def _get_variable_labels(self) -> list[str]:
13681377

13691378
def _get_nobs(self) -> int:
13701379
if self.format_version >= 118:
1371-
return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1380+
return self._read_uint64()
13721381
else:
1373-
return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1382+
return self._read_uint32()
13741383

13751384
def _get_data_label(self) -> str:
13761385
if self.format_version >= 118:
1377-
strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1386+
strlen = self._read_uint16()
13781387
return self._decode(self.path_or_buf.read(strlen))
13791388
elif self.format_version == 117:
1380-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1389+
strlen = self._read_int8()
13811390
return self._decode(self.path_or_buf.read(strlen))
13821391
elif self.format_version > 105:
13831392
return self._decode(self.path_or_buf.read(81))
@@ -1386,10 +1395,10 @@ def _get_data_label(self) -> str:
13861395

13871396
def _get_time_stamp(self) -> str:
13881397
if self.format_version >= 118:
1389-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1398+
strlen = self._read_int8()
13901399
return self.path_or_buf.read(strlen).decode("utf-8")
13911400
elif self.format_version == 117:
1392-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1401+
strlen = self._read_int8()
13931402
return self._decode(self.path_or_buf.read(strlen))
13941403
elif self.format_version > 104:
13951404
return self._decode(self.path_or_buf.read(18))
@@ -1404,22 +1413,20 @@ def _get_seek_variable_labels(self) -> int:
14041413
# variable, 20 for the closing tag and 17 for the opening tag
14051414
return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
14061415
elif self.format_version >= 118:
1407-
return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
1416+
return self._read_int64() + 17
14081417
else:
14091418
raise ValueError()
14101419

14111420
def _read_old_header(self, first_char: bytes) -> None:
1412-
self.format_version = struct.unpack("b", first_char)[0]
1421+
self.format_version = int(first_char[0])
14131422
if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
14141423
raise ValueError(_version_error.format(version=self.format_version))
14151424
self._set_encoding()
1416-
self.byteorder = (
1417-
">" if struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 else "<"
1418-
)
1419-
self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
1425+
self.byteorder = (">" if self._read_int8() == 0x1 else "<")
1426+
self.filetype = self._read_int8()
14201427
self.path_or_buf.read(1) # unused
14211428

1422-
self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1429+
self.nvar = self._read_uint16()
14231430
self.nobs = self._get_nobs()
14241431

14251432
self._data_label = self._get_data_label()
@@ -1428,7 +1435,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14281435

14291436
# descriptors
14301437
if self.format_version > 108:
1431-
typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
1438+
typlist = [int(c) for c in self.path_or_buf.read(self.nvar)]
14321439
else:
14331440
buf = self.path_or_buf.read(self.nvar)
14341441
typlistb = np.frombuffer(buf, dtype=np.uint8)
@@ -1458,10 +1465,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14581465
self.varlist = [
14591466
self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
14601467
]
1461-
self.srtlist = struct.unpack(
1462-
self.byteorder + ("h" * (self.nvar + 1)),
1463-
self.path_or_buf.read(2 * (self.nvar + 1)),
1464-
)[:-1]
1468+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
14651469

14661470
self.fmtlist = self._get_fmtlist()
14671471

@@ -1476,17 +1480,11 @@ def _read_old_header(self, first_char: bytes) -> None:
14761480

14771481
if self.format_version > 104:
14781482
while True:
1479-
data_type = struct.unpack(
1480-
self.byteorder + "b", self.path_or_buf.read(1)
1481-
)[0]
1483+
data_type = self._read_int8()
14821484
if self.format_version > 108:
1483-
data_len = struct.unpack(
1484-
self.byteorder + "i", self.path_or_buf.read(4)
1485-
)[0]
1485+
data_len = self._read_int32()
14861486
else:
1487-
data_len = struct.unpack(
1488-
self.byteorder + "h", self.path_or_buf.read(2)
1489-
)[0]
1487+
data_len = self._read_int16()
14901488
if data_type == 0:
14911489
break
14921490
self.path_or_buf.read(data_len)
@@ -1570,8 +1568,8 @@ def _read_value_labels(self) -> None:
15701568
labname = self._decode(self.path_or_buf.read(129))
15711569
self.path_or_buf.read(3) # padding
15721570

1573-
n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1574-
txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1571+
n = self._read_uint32()
1572+
txtlen = self._read_uint32()
15751573
off = np.frombuffer(
15761574
self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
15771575
)
@@ -1599,7 +1597,7 @@ def _read_strls(self) -> None:
15991597
break
16001598

16011599
if self.format_version == 117:
1602-
v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1600+
v_o = self._read_uint64()
16031601
else:
16041602
buf = self.path_or_buf.read(12)
16051603
# Only tested on little endian file on little endian machine.
@@ -1610,8 +1608,8 @@ def _read_strls(self) -> None:
16101608
# This path may not be correct, impossible to test
16111609
buf = buf[0:v_size] + buf[(4 + v_size) :]
16121610
v_o = struct.unpack("Q", buf)[0]
1613-
typ = struct.unpack("B", self.path_or_buf.read(1))[0]
1614-
length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1611+
typ = self._read_uint8()
1612+
length = self._read_uint32()
16151613
va = self.path_or_buf.read(length)
16161614
if typ == 130:
16171615
decoded_va = va[0:-1].decode(self._encoding)

0 commit comments

Comments
 (0)