Skip to content

Commit 62fd182

Browse files
committed
SA14: Add SQLAlchemy 1.4 compatibility for CrateCompiler
- Code reuse was aimed at, but for the SA <1.4 vs. SA >=1.4 split, two functions, `visit_update_14` and `_get_crud_params_14`, have been vendored separately to accompany `crate.client.sqlalchemy.compiler.CrateCompiler`. All adjustments have now been marked inline with `CrateDB amendment`. - The main query rewriting function for UPDATE statements, `rewrite_update`, needed adjustments to account for a different wrapping/nesting of in/out parameters. - The `cresultproxy` module was temporarily taken out of the equation because it raised some runtime error.
1 parent 1d9e151 commit 62fd182

File tree

5 files changed

+367
-3
lines changed

5 files changed

+367
-3
lines changed

CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Unreleased
2222
- Added support for enabling SSL using SQLAlchemy DB URI with parameter
2323
``?ssl=true``.
2424

25+
- Add support for SQLAlchemy 1.4
26+
2527
2020/09/28 0.26.0
2628
=================
2729

src/crate/client/sqlalchemy/compiler.py

Lines changed: 314 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
from collections import defaultdict
2424

2525
import sqlalchemy as sa
26-
from sqlalchemy.sql import crud
26+
from sqlalchemy.sql import crud, selectable
2727
from sqlalchemy.sql import compiler
2828
from .types import MutableDict
29+
from .sa_version import SA_VERSION, SA_1_4
2930

3031

3132
def rewrite_update(clauseelement, multiparams, params):
@@ -73,7 +74,16 @@ def rewrite_update(clauseelement, multiparams, params):
7374
def crate_before_execute(conn, clauseelement, multiparams, params):
7475
is_crate = type(conn.dialect).__name__ == 'CrateDialect'
7576
if is_crate and isinstance(clauseelement, sa.sql.expression.Update):
76-
return rewrite_update(clauseelement, multiparams, params)
77+
if SA_VERSION >= SA_1_4:
78+
multiparams = ([params],)
79+
params = {}
80+
81+
clauseelement, multiparams, params = rewrite_update(clauseelement, multiparams, params)
82+
83+
if SA_VERSION >= SA_1_4:
84+
params = multiparams[0]
85+
multiparams = []
86+
7787
return clauseelement, multiparams, params
7888

7989

@@ -189,6 +199,9 @@ def visit_update(self, update_stmt, **kw):
189199
Parts are taken from the SQLCompiler base class.
190200
"""
191201

202+
if SA_VERSION >= SA_1_4:
203+
return self.visit_update_14(update_stmt, **kw)
204+
192205
if not update_stmt.parameters and \
193206
not hasattr(update_stmt, '_crate_specific'):
194207
return super(CrateCompiler, self).visit_update(update_stmt, **kw)
@@ -212,11 +225,14 @@ def visit_update(self, update_stmt, **kw):
212225
update_stmt, table_text
213226
)
214227

228+
# CrateDB amendment.
215229
crud_params = self._get_crud_params(update_stmt, **kw)
216230

217231
text += table_text
218232

219233
text += ' SET '
234+
235+
# CrateDB amendment begin.
220236
include_table = extra_froms and \
221237
self.render_table_with_column_in_update_from
222238

@@ -234,6 +250,7 @@ def visit_update(self, update_stmt, **kw):
234250
set_clauses.append(k + ' = ' + self.process(bindparam))
235251

236252
text += ', '.join(set_clauses)
253+
# CrateDB amendment end.
237254

238255
if self.returning or update_stmt._returning:
239256
if not self.returning:
@@ -269,7 +286,6 @@ def visit_update(self, update_stmt, **kw):
269286

270287
def _get_crud_params(compiler, stmt, **kw):
271288
""" extract values from crud parameters
272-
273289
taken from SQLAlchemy's crud module (since 1.0.x) and
274290
adapted for Crate dialect"""
275291

@@ -325,3 +341,298 @@ def _get_crud_params(compiler, stmt, **kw):
325341
values, kw)
326342

327343
return values
344+
345+
def visit_update_14(self, update_stmt, **kw):
346+
347+
compile_state = update_stmt._compile_state_factory(
348+
update_stmt, self, **kw
349+
)
350+
update_stmt = compile_state.statement
351+
352+
toplevel = not self.stack
353+
if toplevel:
354+
self.isupdate = True
355+
if not self.compile_state:
356+
self.compile_state = compile_state
357+
358+
extra_froms = compile_state._extra_froms
359+
is_multitable = bool(extra_froms)
360+
361+
if is_multitable:
362+
# main table might be a JOIN
363+
main_froms = set(selectable._from_objects(update_stmt.table))
364+
render_extra_froms = [
365+
f for f in extra_froms if f not in main_froms
366+
]
367+
correlate_froms = main_froms.union(extra_froms)
368+
else:
369+
render_extra_froms = []
370+
correlate_froms = {update_stmt.table}
371+
372+
self.stack.append(
373+
{
374+
"correlate_froms": correlate_froms,
375+
"asfrom_froms": correlate_froms,
376+
"selectable": update_stmt,
377+
}
378+
)
379+
380+
text = "UPDATE "
381+
382+
if update_stmt._prefixes:
383+
text += self._generate_prefixes(
384+
update_stmt, update_stmt._prefixes, **kw
385+
)
386+
387+
table_text = self.update_tables_clause(
388+
update_stmt, update_stmt.table, render_extra_froms, **kw
389+
)
390+
391+
# CrateDB amendment.
392+
crud_params = _get_crud_params_14(
393+
self, update_stmt, compile_state, **kw
394+
)
395+
396+
if update_stmt._hints:
397+
dialect_hints, table_text = self._setup_crud_hints(
398+
update_stmt, table_text
399+
)
400+
else:
401+
dialect_hints = None
402+
403+
text += table_text
404+
405+
text += " SET "
406+
407+
# CrateDB amendment begin.
408+
include_table = extra_froms and \
409+
self.render_table_with_column_in_update_from
410+
411+
set_clauses = []
412+
413+
for c, expr, value in crud_params:
414+
key = c._compiler_dispatch(self, include_table=include_table)
415+
clause = key + ' = ' + value
416+
set_clauses.append(clause)
417+
418+
for k, v in compile_state._dict_parameters.items():
419+
if isinstance(k, str) and '[' in k:
420+
bindparam = sa.sql.bindparam(k, v)
421+
clause = k + ' = ' + self.process(bindparam)
422+
set_clauses.append(clause)
423+
424+
text += ', '.join(set_clauses)
425+
# CrateDB amendment end.
426+
427+
if self.returning or update_stmt._returning:
428+
if self.returning_precedes_values:
429+
text += " " + self.returning_clause(
430+
update_stmt, self.returning or update_stmt._returning
431+
)
432+
433+
if extra_froms:
434+
extra_from_text = self.update_from_clause(
435+
update_stmt,
436+
update_stmt.table,
437+
render_extra_froms,
438+
dialect_hints,
439+
**kw
440+
)
441+
if extra_from_text:
442+
text += " " + extra_from_text
443+
444+
if update_stmt._where_criteria:
445+
t = self._generate_delimited_and_list(
446+
update_stmt._where_criteria, **kw
447+
)
448+
if t:
449+
text += " WHERE " + t
450+
451+
limit_clause = self.update_limit_clause(update_stmt)
452+
if limit_clause:
453+
text += " " + limit_clause
454+
455+
if (
456+
self.returning or update_stmt._returning
457+
) and not self.returning_precedes_values:
458+
text += " " + self.returning_clause(
459+
update_stmt, self.returning or update_stmt._returning
460+
)
461+
462+
if self.ctes and toplevel:
463+
text = self._render_cte_clause() + text
464+
465+
self.stack.pop(-1)
466+
467+
return text
468+
469+
470+
def _get_crud_params_14(compiler, stmt, compile_state, **kw):
471+
"""create a set of tuples representing column/string pairs for use
472+
in an INSERT or UPDATE statement.
473+
474+
Also generates the Compiled object's postfetch, prefetch, and
475+
returning column collections, used for default handling and ultimately
476+
populating the CursorResult's prefetch_cols() and postfetch_cols()
477+
collections.
478+
479+
"""
480+
from sqlalchemy.sql.crud import _key_getters_for_crud_column
481+
from sqlalchemy.sql.crud import _create_bind_param
482+
from sqlalchemy.sql.crud import REQUIRED
483+
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
484+
from sqlalchemy.sql.crud import _get_multitable_params
485+
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
486+
from sqlalchemy.sql.crud import _scan_cols
487+
from sqlalchemy import exc # noqa: F401
488+
from sqlalchemy.sql.crud import _extend_values_for_multiparams
489+
490+
compiler.postfetch = []
491+
compiler.insert_prefetch = []
492+
compiler.update_prefetch = []
493+
compiler.returning = []
494+
495+
# getters - these are normally just column.key,
496+
# but in the case of mysql multi-table update, the rules for
497+
# .key must conditionally take tablename into account
498+
(
499+
_column_as_key,
500+
_getattr_col_key,
501+
_col_bind_name,
502+
) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
503+
504+
compiler._key_getters_for_crud_column = getters
505+
506+
# no parameters in the statement, no parameters in the
507+
# compiled params - return binds for all columns
508+
if compiler.column_keys is None and compile_state._no_parameters:
509+
return [
510+
(
511+
c,
512+
compiler.preparer.format_column(c),
513+
_create_bind_param(compiler, c, None, required=True),
514+
)
515+
for c in stmt.table.columns
516+
]
517+
518+
if compile_state._has_multi_parameters:
519+
spd = compile_state._multi_parameters[0]
520+
stmt_parameter_tuples = list(spd.items())
521+
elif compile_state._ordered_values:
522+
spd = compile_state._dict_parameters
523+
stmt_parameter_tuples = compile_state._ordered_values
524+
elif compile_state._dict_parameters:
525+
spd = compile_state._dict_parameters
526+
stmt_parameter_tuples = list(spd.items())
527+
else:
528+
stmt_parameter_tuples = spd = None
529+
530+
# if we have statement parameters - set defaults in the
531+
# compiled params
532+
if compiler.column_keys is None:
533+
parameters = {}
534+
elif stmt_parameter_tuples:
535+
parameters = dict(
536+
(_column_as_key(key), REQUIRED)
537+
for key in compiler.column_keys
538+
if key not in spd
539+
)
540+
else:
541+
parameters = dict(
542+
(_column_as_key(key), REQUIRED) for key in compiler.column_keys
543+
)
544+
545+
# create a list of column assignment clauses as tuples
546+
values = []
547+
548+
if stmt_parameter_tuples is not None:
549+
_get_stmt_parameter_tuples_params(
550+
compiler,
551+
compile_state,
552+
parameters,
553+
stmt_parameter_tuples,
554+
_column_as_key,
555+
values,
556+
kw,
557+
)
558+
559+
check_columns = {}
560+
561+
# special logic that only occurs for multi-table UPDATE
562+
# statements
563+
if compile_state.isupdate and compile_state.is_multitable:
564+
_get_multitable_params(
565+
compiler,
566+
stmt,
567+
compile_state,
568+
stmt_parameter_tuples,
569+
check_columns,
570+
_col_bind_name,
571+
_getattr_col_key,
572+
values,
573+
kw,
574+
)
575+
576+
if compile_state.isinsert and stmt._select_names:
577+
_scan_insert_from_select_cols(
578+
compiler,
579+
stmt,
580+
compile_state,
581+
parameters,
582+
_getattr_col_key,
583+
_column_as_key,
584+
_col_bind_name,
585+
check_columns,
586+
values,
587+
kw,
588+
)
589+
else:
590+
_scan_cols(
591+
compiler,
592+
stmt,
593+
compile_state,
594+
parameters,
595+
_getattr_col_key,
596+
_column_as_key,
597+
_col_bind_name,
598+
check_columns,
599+
values,
600+
kw,
601+
)
602+
603+
# CrateDB amendment.
604+
# The rewriting logic in `rewrite_update` and `visit_update` needs
605+
# adjustments here in order to prevent `sqlalchemy.exc.CompileError:
606+
# Unconsumed column names: characters_name, data['nested']`
607+
"""
608+
if parameters and stmt_parameter_tuples:
609+
check = (
610+
set(parameters)
611+
.intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
612+
.difference(check_columns)
613+
)
614+
if check:
615+
raise exc.CompileError(
616+
"Unconsumed column names: %s"
617+
% (", ".join("%s" % (c,) for c in check))
618+
)
619+
"""
620+
621+
if compile_state._has_multi_parameters:
622+
values = _extend_values_for_multiparams(
623+
compiler, stmt, compile_state, values, kw
624+
)
625+
elif not values and compiler.for_executemany:
626+
# convert an "INSERT DEFAULT VALUES"
627+
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
628+
# into an in-place multi values. This supports
629+
# insert_executemany_returning mode :)
630+
values = [
631+
(
632+
stmt.table.columns[0],
633+
compiler.preparer.format_column(stmt.table.columns[0]),
634+
"DEFAULT",
635+
)
636+
]
637+
638+
return values

src/crate/client/sqlalchemy/dialect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
import logging
2323
from datetime import datetime, date
2424

25+
# FIXME: Workaround to be able to use SQLAlchemy 1.4.
26+
# Caveat: This purges the ``cresultproxy`` extension
27+
# at runtime, so it will impose a speed bump.
28+
import crate.client.sqlalchemy.monkey # noqa:F401
29+
2530
from sqlalchemy import types as sqltypes
2631
from sqlalchemy.engine import default, reflection
2732
from sqlalchemy.sql import functions

0 commit comments

Comments
 (0)