Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 910c98d

Browse files
authored
Merge pull request #369 from datafold/jan17
Bugfix: Add brackets around WHERE clause
2 parents 2a4ea5d + 1ed7ce0 commit 910c98d

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
185185
metavar="COUNT",
186186
)
187187
@click.option(
188-
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR"
188+
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space. Beware of SQL Injection!", metavar="EXPR"
189189
)
190190
@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm]))
191191
@click.option(

data_diff/joindiff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def _test_null_keys(self, table1, table2):
253253
q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
254254
nulls = ts.database.query(q, list)
255255
if nulls:
256-
raise ValueError("NULL values in one or more primary keys")
256+
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")
257257

258258
def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
259259
logger.debug(f"Collecting stats for table #{i}")

data_diff/table_segment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ def __post_init__(self):
6767
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
6868
)
6969

70+
def _where(self):
71+
return f"({self.where})" if self.where else None
72+
7073
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
71-
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where)
74+
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
7275
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
7376

7477
def with_schema(self) -> "TableSegment":
@@ -100,7 +103,7 @@ def source_table(self):
100103

101104
def make_select(self):
102105
return self.source_table.where(
103-
*self._make_key_range(), *self._make_update_range(), Code(self.where) if self.where else SKIP
106+
*self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP
104107
)
105108

106109
def get_values(self) -> list:

tests/test_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_api(self):
4949

5050
# test where
5151
diff_id = diff[0][1][0]
52-
where = f"id != {diff_id}"
52+
where = f"id != {diff_id} OR id = 90000000"
5353

5454
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name, where=where)
5555
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name, where=where)

0 commit comments

Comments
 (0)