23
23
from collections import defaultdict
24
24
25
25
import sqlalchemy as sa
26
- from sqlalchemy .sql import crud
26
+ from sqlalchemy .sql import crud , selectable
27
27
from sqlalchemy .sql import compiler
28
28
from .types import MutableDict
29
+ from .sa_version import SA_VERSION , SA_1_4
29
30
30
31
31
32
def rewrite_update (clauseelement , multiparams , params ):
@@ -73,7 +74,16 @@ def rewrite_update(clauseelement, multiparams, params):
73
74
def crate_before_execute (conn , clauseelement , multiparams , params ):
74
75
is_crate = type (conn .dialect ).__name__ == 'CrateDialect'
75
76
if is_crate and isinstance (clauseelement , sa .sql .expression .Update ):
76
- return rewrite_update (clauseelement , multiparams , params )
77
+ if SA_VERSION >= SA_1_4 :
78
+ multiparams = ([params ],)
79
+ params = {}
80
+
81
+ clauseelement , multiparams , params = rewrite_update (clauseelement , multiparams , params )
82
+
83
+ if SA_VERSION >= SA_1_4 :
84
+ params = multiparams [0 ]
85
+ multiparams = []
86
+
77
87
return clauseelement , multiparams , params
78
88
79
89
@@ -189,6 +199,9 @@ def visit_update(self, update_stmt, **kw):
189
199
Parts are taken from the SQLCompiler base class.
190
200
"""
191
201
202
+ if SA_VERSION >= SA_1_4 :
203
+ return self .visit_update_14 (update_stmt , ** kw )
204
+
192
205
if not update_stmt .parameters and \
193
206
not hasattr (update_stmt , '_crate_specific' ):
194
207
return super (CrateCompiler , self ).visit_update (update_stmt , ** kw )
@@ -212,11 +225,14 @@ def visit_update(self, update_stmt, **kw):
212
225
update_stmt , table_text
213
226
)
214
227
228
+ # CrateDB amendment.
215
229
crud_params = self ._get_crud_params (update_stmt , ** kw )
216
230
217
231
text += table_text
218
232
219
233
text += ' SET '
234
+
235
+ # CrateDB amendment begin.
220
236
include_table = extra_froms and \
221
237
self .render_table_with_column_in_update_from
222
238
@@ -234,6 +250,7 @@ def visit_update(self, update_stmt, **kw):
234
250
set_clauses .append (k + ' = ' + self .process (bindparam ))
235
251
236
252
text += ', ' .join (set_clauses )
253
+ # CrateDB amendment end.
237
254
238
255
if self .returning or update_stmt ._returning :
239
256
if not self .returning :
@@ -269,7 +286,6 @@ def visit_update(self, update_stmt, **kw):
269
286
270
287
def _get_crud_params (compiler , stmt , ** kw ):
271
288
""" extract values from crud parameters
272
-
273
289
taken from SQLAlchemy's crud module (since 1.0.x) and
274
290
adapted for Crate dialect"""
275
291
@@ -325,3 +341,298 @@ def _get_crud_params(compiler, stmt, **kw):
325
341
values , kw )
326
342
327
343
return values
344
+
345
+ def visit_update_14 (self , update_stmt , ** kw ):
346
+
347
+ compile_state = update_stmt ._compile_state_factory (
348
+ update_stmt , self , ** kw
349
+ )
350
+ update_stmt = compile_state .statement
351
+
352
+ toplevel = not self .stack
353
+ if toplevel :
354
+ self .isupdate = True
355
+ if not self .compile_state :
356
+ self .compile_state = compile_state
357
+
358
+ extra_froms = compile_state ._extra_froms
359
+ is_multitable = bool (extra_froms )
360
+
361
+ if is_multitable :
362
+ # main table might be a JOIN
363
+ main_froms = set (selectable ._from_objects (update_stmt .table ))
364
+ render_extra_froms = [
365
+ f for f in extra_froms if f not in main_froms
366
+ ]
367
+ correlate_froms = main_froms .union (extra_froms )
368
+ else :
369
+ render_extra_froms = []
370
+ correlate_froms = {update_stmt .table }
371
+
372
+ self .stack .append (
373
+ {
374
+ "correlate_froms" : correlate_froms ,
375
+ "asfrom_froms" : correlate_froms ,
376
+ "selectable" : update_stmt ,
377
+ }
378
+ )
379
+
380
+ text = "UPDATE "
381
+
382
+ if update_stmt ._prefixes :
383
+ text += self ._generate_prefixes (
384
+ update_stmt , update_stmt ._prefixes , ** kw
385
+ )
386
+
387
+ table_text = self .update_tables_clause (
388
+ update_stmt , update_stmt .table , render_extra_froms , ** kw
389
+ )
390
+
391
+ # CrateDB amendment.
392
+ crud_params = _get_crud_params_14 (
393
+ self , update_stmt , compile_state , ** kw
394
+ )
395
+
396
+ if update_stmt ._hints :
397
+ dialect_hints , table_text = self ._setup_crud_hints (
398
+ update_stmt , table_text
399
+ )
400
+ else :
401
+ dialect_hints = None
402
+
403
+ text += table_text
404
+
405
+ text += " SET "
406
+
407
+ # CrateDB amendment begin.
408
+ include_table = extra_froms and \
409
+ self .render_table_with_column_in_update_from
410
+
411
+ set_clauses = []
412
+
413
+ for c , expr , value in crud_params :
414
+ key = c ._compiler_dispatch (self , include_table = include_table )
415
+ clause = key + ' = ' + value
416
+ set_clauses .append (clause )
417
+
418
+ for k , v in compile_state ._dict_parameters .items ():
419
+ if isinstance (k , str ) and '[' in k :
420
+ bindparam = sa .sql .bindparam (k , v )
421
+ clause = k + ' = ' + self .process (bindparam )
422
+ set_clauses .append (clause )
423
+
424
+ text += ', ' .join (set_clauses )
425
+ # CrateDB amendment end.
426
+
427
+ if self .returning or update_stmt ._returning :
428
+ if self .returning_precedes_values :
429
+ text += " " + self .returning_clause (
430
+ update_stmt , self .returning or update_stmt ._returning
431
+ )
432
+
433
+ if extra_froms :
434
+ extra_from_text = self .update_from_clause (
435
+ update_stmt ,
436
+ update_stmt .table ,
437
+ render_extra_froms ,
438
+ dialect_hints ,
439
+ ** kw
440
+ )
441
+ if extra_from_text :
442
+ text += " " + extra_from_text
443
+
444
+ if update_stmt ._where_criteria :
445
+ t = self ._generate_delimited_and_list (
446
+ update_stmt ._where_criteria , ** kw
447
+ )
448
+ if t :
449
+ text += " WHERE " + t
450
+
451
+ limit_clause = self .update_limit_clause (update_stmt )
452
+ if limit_clause :
453
+ text += " " + limit_clause
454
+
455
+ if (
456
+ self .returning or update_stmt ._returning
457
+ ) and not self .returning_precedes_values :
458
+ text += " " + self .returning_clause (
459
+ update_stmt , self .returning or update_stmt ._returning
460
+ )
461
+
462
+ if self .ctes and toplevel :
463
+ text = self ._render_cte_clause () + text
464
+
465
+ self .stack .pop (- 1 )
466
+
467
+ return text
468
+
469
+
470
+ def _get_crud_params_14 (compiler , stmt , compile_state , ** kw ):
471
+ """create a set of tuples representing column/string pairs for use
472
+ in an INSERT or UPDATE statement.
473
+
474
+ Also generates the Compiled object's postfetch, prefetch, and
475
+ returning column collections, used for default handling and ultimately
476
+ populating the CursorResult's prefetch_cols() and postfetch_cols()
477
+ collections.
478
+
479
+ """
480
+ from sqlalchemy .sql .crud import _key_getters_for_crud_column
481
+ from sqlalchemy .sql .crud import _create_bind_param
482
+ from sqlalchemy .sql .crud import REQUIRED
483
+ from sqlalchemy .sql .crud import _get_stmt_parameter_tuples_params
484
+ from sqlalchemy .sql .crud import _get_multitable_params
485
+ from sqlalchemy .sql .crud import _scan_insert_from_select_cols
486
+ from sqlalchemy .sql .crud import _scan_cols
487
+ from sqlalchemy import exc # noqa: F401
488
+ from sqlalchemy .sql .crud import _extend_values_for_multiparams
489
+
490
+ compiler .postfetch = []
491
+ compiler .insert_prefetch = []
492
+ compiler .update_prefetch = []
493
+ compiler .returning = []
494
+
495
+ # getters - these are normally just column.key,
496
+ # but in the case of mysql multi-table update, the rules for
497
+ # .key must conditionally take tablename into account
498
+ (
499
+ _column_as_key ,
500
+ _getattr_col_key ,
501
+ _col_bind_name ,
502
+ ) = getters = _key_getters_for_crud_column (compiler , stmt , compile_state )
503
+
504
+ compiler ._key_getters_for_crud_column = getters
505
+
506
+ # no parameters in the statement, no parameters in the
507
+ # compiled params - return binds for all columns
508
+ if compiler .column_keys is None and compile_state ._no_parameters :
509
+ return [
510
+ (
511
+ c ,
512
+ compiler .preparer .format_column (c ),
513
+ _create_bind_param (compiler , c , None , required = True ),
514
+ )
515
+ for c in stmt .table .columns
516
+ ]
517
+
518
+ if compile_state ._has_multi_parameters :
519
+ spd = compile_state ._multi_parameters [0 ]
520
+ stmt_parameter_tuples = list (spd .items ())
521
+ elif compile_state ._ordered_values :
522
+ spd = compile_state ._dict_parameters
523
+ stmt_parameter_tuples = compile_state ._ordered_values
524
+ elif compile_state ._dict_parameters :
525
+ spd = compile_state ._dict_parameters
526
+ stmt_parameter_tuples = list (spd .items ())
527
+ else :
528
+ stmt_parameter_tuples = spd = None
529
+
530
+ # if we have statement parameters - set defaults in the
531
+ # compiled params
532
+ if compiler .column_keys is None :
533
+ parameters = {}
534
+ elif stmt_parameter_tuples :
535
+ parameters = dict (
536
+ (_column_as_key (key ), REQUIRED )
537
+ for key in compiler .column_keys
538
+ if key not in spd
539
+ )
540
+ else :
541
+ parameters = dict (
542
+ (_column_as_key (key ), REQUIRED ) for key in compiler .column_keys
543
+ )
544
+
545
+ # create a list of column assignment clauses as tuples
546
+ values = []
547
+
548
+ if stmt_parameter_tuples is not None :
549
+ _get_stmt_parameter_tuples_params (
550
+ compiler ,
551
+ compile_state ,
552
+ parameters ,
553
+ stmt_parameter_tuples ,
554
+ _column_as_key ,
555
+ values ,
556
+ kw ,
557
+ )
558
+
559
+ check_columns = {}
560
+
561
+ # special logic that only occurs for multi-table UPDATE
562
+ # statements
563
+ if compile_state .isupdate and compile_state .is_multitable :
564
+ _get_multitable_params (
565
+ compiler ,
566
+ stmt ,
567
+ compile_state ,
568
+ stmt_parameter_tuples ,
569
+ check_columns ,
570
+ _col_bind_name ,
571
+ _getattr_col_key ,
572
+ values ,
573
+ kw ,
574
+ )
575
+
576
+ if compile_state .isinsert and stmt ._select_names :
577
+ _scan_insert_from_select_cols (
578
+ compiler ,
579
+ stmt ,
580
+ compile_state ,
581
+ parameters ,
582
+ _getattr_col_key ,
583
+ _column_as_key ,
584
+ _col_bind_name ,
585
+ check_columns ,
586
+ values ,
587
+ kw ,
588
+ )
589
+ else :
590
+ _scan_cols (
591
+ compiler ,
592
+ stmt ,
593
+ compile_state ,
594
+ parameters ,
595
+ _getattr_col_key ,
596
+ _column_as_key ,
597
+ _col_bind_name ,
598
+ check_columns ,
599
+ values ,
600
+ kw ,
601
+ )
602
+
603
+ # CrateDB amendment.
604
+ # The rewriting logic in `rewrite_update` and `visit_update` needs
605
+ # adjustments here in order to prevent `sqlalchemy.exc.CompileError:
606
+ # Unconsumed column names: characters_name, data['nested']`
607
+ """
608
+ if parameters and stmt_parameter_tuples:
609
+ check = (
610
+ set(parameters)
611
+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
612
+ .difference(check_columns)
613
+ )
614
+ if check:
615
+ raise exc.CompileError(
616
+ "Unconsumed column names: %s"
617
+ % (", ".join("%s" % (c,) for c in check))
618
+ )
619
+ """
620
+
621
+ if compile_state ._has_multi_parameters :
622
+ values = _extend_values_for_multiparams (
623
+ compiler , stmt , compile_state , values , kw
624
+ )
625
+ elif not values and compiler .for_executemany :
626
+ # convert an "INSERT DEFAULT VALUES"
627
+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
628
+ # into an in-place multi values. This supports
629
+ # insert_executemany_returning mode :)
630
+ values = [
631
+ (
632
+ stmt .table .columns [0 ],
633
+ compiler .preparer .format_column (stmt .table .columns [0 ]),
634
+ "DEFAULT" ,
635
+ )
636
+ ]
637
+
638
+ return values
0 commit comments