Skip to content

Commit 708c8aa

Browse files
authored
Allow tuple-valued params in read_sql[_query] (#997)
1 parent 336718a commit 708c8aa

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

pandas-stubs/io/sql.pyi

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,14 @@ def read_sql_query(
6666
con: _SQLConnection,
6767
index_col: str | list[str] | None = ...,
6868
coerce_float: bool = ...,
69-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
69+
params: (
70+
list[Scalar]
71+
| tuple[Scalar, ...]
72+
| tuple[tuple[Scalar, ...], ...]
73+
| Mapping[str, Scalar]
74+
| Mapping[str, tuple[Scalar, ...]]
75+
| None
76+
) = ...,
7077
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
7178
*,
7279
chunksize: int,
@@ -79,7 +86,14 @@ def read_sql_query(
7986
con: _SQLConnection,
8087
index_col: str | list[str] | None = ...,
8188
coerce_float: bool = ...,
82-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
89+
params: (
90+
list[Scalar]
91+
| tuple[Scalar, ...]
92+
| tuple[tuple[Scalar, ...], ...]
93+
| Mapping[str, Scalar]
94+
| Mapping[str, tuple[Scalar, ...]]
95+
| None
96+
) = ...,
8397
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
8498
chunksize: None = ...,
8599
dtype: DtypeArg | None = ...,
@@ -91,7 +105,14 @@ def read_sql(
91105
con: _SQLConnection,
92106
index_col: str | list[str] | None = ...,
93107
coerce_float: bool = ...,
94-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
108+
params: (
109+
list[Scalar]
110+
| tuple[Scalar, ...]
111+
| tuple[tuple[Scalar, ...], ...]
112+
| Mapping[str, Scalar]
113+
| Mapping[str, tuple[Scalar, ...]]
114+
| None
115+
) = ...,
95116
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
96117
columns: list[str] = ...,
97118
*,
@@ -105,7 +126,14 @@ def read_sql(
105126
con: _SQLConnection,
106127
index_col: str | list[str] | None = ...,
107128
coerce_float: bool = ...,
108-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
129+
params: (
130+
list[Scalar]
131+
| tuple[Scalar, ...]
132+
| tuple[tuple[Scalar, ...], ...]
133+
| Mapping[str, Scalar]
134+
| Mapping[str, tuple[Scalar, ...]]
135+
| None
136+
) = ...,
109137
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
110138
columns: list[str] = ...,
111139
chunksize: None = ...,

tests/test_io.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,39 @@ def test_read_sql_query_via_sqlalchemy_engine_with_params():
12381238
engine.dispose()
12391239

12401240

1241+
@pytest.mark.skip(
1242+
reason="Only works in Postgres (and MySQL, but with different query syntax)"
1243+
)
1244+
def test_read_sql_query_via_sqlalchemy_engine_with_tuple_valued_params():
1245+
with ensure_clean() as path:
1246+
db_uri = "postgresql+psycopg2://postgres@localhost:5432/postgres"
1247+
engine = sqlalchemy.create_engine(db_uri)
1248+
1249+
check(
1250+
assert_type(
1251+
read_sql_query(
1252+
"select * from test where a in %(a)s",
1253+
con=engine,
1254+
params={"a": (1, 2)},
1255+
),
1256+
DataFrame,
1257+
),
1258+
DataFrame,
1259+
)
1260+
check(
1261+
assert_type(
1262+
read_sql_query(
1263+
"select * from test where a in %s",
1264+
con=engine,
1265+
params=((1, 2),),
1266+
),
1267+
DataFrame,
1268+
),
1269+
DataFrame,
1270+
)
1271+
engine.dispose()
1272+
1273+
12411274
def test_read_html():
12421275
check(assert_type(DF.to_html(), str), str)
12431276
with ensure_clean() as path:

0 commit comments

Comments
 (0)