@@ -433,159 +433,27 @@ impl AsyncPgConnection {
433
433
// so there is no need to even access the query in the async block below
434
434
let mut query_builder = PgQueryBuilder :: default ( ) ;
435
435
436
- let ( collect_bind_result, fake_oid_locations, generated_oids, bind_collector) = {
437
- // we don't resolve custom types here yet, we do that later
438
- // in the async block below as we might need to perform lookup
439
- // queries for that.
440
- //
441
- // We apply this workaround to prevent requiring all the diesel
442
- // serialization code to beeing async
443
- //
444
- // We give out constant fake oids here to optimize for the "happy" path
445
- // without custom type lookup
446
- let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
447
- let mut metadata_lookup_0 = PgAsyncMetadataLookup {
448
- custom_oid : false ,
449
- generated_oids : None ,
450
- oid_generator : |_, _| ( FAKE_OID , FAKE_OID ) ,
451
- } ;
452
- let collect_bind_result_0 =
453
- query. collect_binds ( & mut bind_collector_0, & mut metadata_lookup_0, & Pg ) ;
454
-
455
- // we have encountered a custom type oid, so we need to perform more work here.
456
- // These oids can occure in two locations:
457
- //
458
- // * In the collected metadata -> relativly easy to resolve, just need to replace them below
459
- // * As part of the seralized bind blob -> hard to replace
460
- //
461
- // To address the second case, we perform a second run of the bind collector
462
- // with a different set of fake oids. Then we compare the output of the two runs
463
- // and use that information to infer where to replace bytes in the serialized output
464
-
465
- if metadata_lookup_0. custom_oid {
466
- // we try to get the maxium oid we encountered here
467
- // to be sure that we don't accidently give out a fake oid below that collides with
468
- // something
469
- let mut max_oid = bind_collector_0
470
- . metadata
471
- . iter ( )
472
- . flat_map ( |t| {
473
- [
474
- t. oid ( ) . unwrap_or_default ( ) ,
475
- t. array_oid ( ) . unwrap_or_default ( ) ,
476
- ]
477
- } )
478
- . max ( )
479
- . unwrap_or_default ( ) ;
480
- let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
481
- let mut metadata_lookup_1 = PgAsyncMetadataLookup {
482
- custom_oid : false ,
483
- generated_oids : Some ( HashMap :: new ( ) ) ,
484
- oid_generator : move |_, _| {
485
- max_oid += 2 ;
486
- ( max_oid, max_oid + 1 )
487
- } ,
488
- } ;
489
- let collect_bind_result_2 =
490
- query. collect_binds ( & mut bind_collector_1, & mut metadata_lookup_1, & Pg ) ;
491
-
492
- assert_eq ! (
493
- bind_collector_0. binds. len( ) ,
494
- bind_collector_0. metadata. len( )
495
- ) ;
496
- let fake_oid_locations = std:: iter:: zip (
497
- bind_collector_0
498
- . binds
499
- . iter ( )
500
- . zip ( & bind_collector_0. metadata ) ,
501
- & bind_collector_1. binds ,
502
- )
503
- . enumerate ( )
504
- . flat_map ( |( bind_index, ( ( bytes_0, metadata_0) , bytes_1) ) | {
505
- // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
506
- // in both cases the relevant buffer is a custom type on it's own
507
- // so we only need to check the cases that contain a fake OID on their own
508
- let ( bytes_0, bytes_1) = if matches ! ( metadata_0. oid( ) , Ok ( FAKE_OID ) ) {
509
- (
510
- bytes_0. as_deref ( ) . unwrap_or_default ( ) ,
511
- bytes_1. as_deref ( ) . unwrap_or_default ( ) ,
512
- )
513
- } else {
514
- // for all other cases, just return an empty
515
- // list to make the iteration below a no-op
516
- // and prevent the need of boxing
517
- ( & [ ] as & [ _ ] , & [ ] as & [ _ ] )
518
- } ;
519
- let lookup_map = metadata_lookup_1
520
- . generated_oids
521
- . as_ref ( )
522
- . map ( |map| {
523
- map. values ( )
524
- . flat_map ( |( oid, array_oid) | [ * oid, * array_oid] )
525
- . collect :: < HashSet < _ > > ( )
526
- } )
527
- . unwrap_or_default ( ) ;
528
- std:: iter:: zip (
529
- bytes_0. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
530
- bytes_1. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
531
- )
532
- . enumerate ( )
533
- . filter_map ( move |( byte_index, ( l, r) ) | {
534
- // here we infer if some byte sequence is a fake oid
535
- // We use the following conditions for that:
536
- //
537
- // * The first byte sequence matches the constant FAKE_OID
538
- // * The second sequence does not match the constant FAKE_OID
539
- // * The second sequence is contained in the set of generated oid,
540
- // otherwise we get false positives around the boundary
541
- // of a to be replaced byte sequence
542
- let r_val =
543
- u32:: from_be_bytes ( r. try_into ( ) . expect ( "That's the right size" ) ) ;
544
- ( l == FAKE_OID . to_be_bytes ( )
545
- && r != FAKE_OID . to_be_bytes ( )
546
- && lookup_map. contains ( & r_val) )
547
- . then_some ( ( bind_index, byte_index) )
548
- } )
549
- } )
550
- // Avoid storing the bind collectors in the returned Future
551
- . collect :: < Vec < _ > > ( ) ;
552
- (
553
- collect_bind_result_0. and ( collect_bind_result_2) ,
554
- fake_oid_locations,
555
- metadata_lookup_1. generated_oids ,
556
- bind_collector_1,
557
- )
558
- } else {
559
- ( collect_bind_result_0, Vec :: new ( ) , None , bind_collector_0)
560
- }
561
- } ;
436
+ let bind_data = construct_bind_data ( & query) ;
562
437
563
438
// The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
564
439
self . with_prepared_statement_after_sql_built (
565
440
callback,
566
441
query. is_safe_to_cache_prepared ( & Pg ) ,
567
442
T :: query_id ( ) ,
568
443
query. to_sql ( & mut query_builder, & Pg ) ,
569
- collect_bind_result,
570
444
query_builder,
571
- bind_collector,
572
- fake_oid_locations,
573
- generated_oids,
445
+ bind_data,
574
446
)
575
447
}
576
448
577
- #[ allow( clippy:: too_many_arguments) ]
578
449
fn with_prepared_statement_after_sql_built < ' a , F , R > (
579
450
& mut self ,
580
451
callback : fn ( Arc < tokio_postgres:: Client > , Statement , Vec < ToSqlHelper > ) -> F ,
581
452
is_safe_to_cache_prepared : QueryResult < bool > ,
582
453
query_id : Option < std:: any:: TypeId > ,
583
454
to_sql_result : QueryResult < ( ) > ,
584
- collect_bind_result : QueryResult < ( ) > ,
585
455
query_builder : PgQueryBuilder ,
586
- mut bind_collector : RawBytesBindCollector < Pg > ,
587
- fake_oid_locations : Vec < ( usize , usize ) > ,
588
- generated_oids : GeneratedOidTypeMap ,
456
+ bind_data : BindData ,
589
457
) -> BoxFuture < ' a , QueryResult < R > >
590
458
where
591
459
F : Future < Output = QueryResult < R > > + Send + ' a ,
@@ -596,6 +464,12 @@ impl AsyncPgConnection {
596
464
let metadata_cache = self . metadata_cache . clone ( ) ;
597
465
let tm = self . transaction_state . clone ( ) ;
598
466
let instrumentation = self . instrumentation . clone ( ) ;
467
+ let BindData {
468
+ collect_bind_result,
469
+ fake_oid_locations,
470
+ generated_oids,
471
+ mut bind_collector,
472
+ } = bind_data;
599
473
600
474
async move {
601
475
let sql = to_sql_result. map ( |_| query_builder. finish ( ) ) ?;
@@ -710,6 +584,142 @@ impl AsyncPgConnection {
710
584
}
711
585
}
712
586
587
+ struct BindData {
588
+ collect_bind_result : Result < ( ) , Error > ,
589
+ fake_oid_locations : Vec < ( usize , usize ) > ,
590
+ generated_oids : GeneratedOidTypeMap ,
591
+ bind_collector : RawBytesBindCollector < Pg > ,
592
+ }
593
+
594
+ fn construct_bind_data ( query : & dyn QueryFragment < diesel:: pg:: Pg > ) -> BindData {
595
+ // we don't resolve custom types here yet, we do that later
596
+ // in the async block below as we might need to perform lookup
597
+ // queries for that.
598
+ //
599
+ // We apply this workaround to prevent requiring all the diesel
600
+ // serialization code to beeing async
601
+ //
602
+ // We give out constant fake oids here to optimize for the "happy" path
603
+ // without custom type lookup
604
+ let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
605
+ let mut metadata_lookup_0 = PgAsyncMetadataLookup {
606
+ custom_oid : false ,
607
+ generated_oids : None ,
608
+ oid_generator : |_, _| ( FAKE_OID , FAKE_OID ) ,
609
+ } ;
610
+ let collect_bind_result_0 =
611
+ query. collect_binds ( & mut bind_collector_0, & mut metadata_lookup_0, & Pg ) ;
612
+ // we have encountered a custom type oid, so we need to perform more work here.
613
+ // These oids can occure in two locations:
614
+ //
615
+ // * In the collected metadata -> relativly easy to resolve, just need to replace them below
616
+ // * As part of the seralized bind blob -> hard to replace
617
+ //
618
+ // To address the second case, we perform a second run of the bind collector
619
+ // with a different set of fake oids. Then we compare the output of the two runs
620
+ // and use that information to infer where to replace bytes in the serialized output
621
+ if metadata_lookup_0. custom_oid {
622
+ // we try to get the maxium oid we encountered here
623
+ // to be sure that we don't accidently give out a fake oid below that collides with
624
+ // something
625
+ let mut max_oid = bind_collector_0
626
+ . metadata
627
+ . iter ( )
628
+ . flat_map ( |t| {
629
+ [
630
+ t. oid ( ) . unwrap_or_default ( ) ,
631
+ t. array_oid ( ) . unwrap_or_default ( ) ,
632
+ ]
633
+ } )
634
+ . max ( )
635
+ . unwrap_or_default ( ) ;
636
+ let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
637
+ let mut metadata_lookup_1 = PgAsyncMetadataLookup {
638
+ custom_oid : false ,
639
+ generated_oids : Some ( HashMap :: new ( ) ) ,
640
+ oid_generator : move |_, _| {
641
+ max_oid += 2 ;
642
+ ( max_oid, max_oid + 1 )
643
+ } ,
644
+ } ;
645
+ let collect_bind_result_1 =
646
+ query. collect_binds ( & mut bind_collector_1, & mut metadata_lookup_1, & Pg ) ;
647
+
648
+ assert_eq ! (
649
+ bind_collector_0. binds. len( ) ,
650
+ bind_collector_0. metadata. len( )
651
+ ) ;
652
+ let fake_oid_locations = std:: iter:: zip (
653
+ bind_collector_0
654
+ . binds
655
+ . iter ( )
656
+ . zip ( & bind_collector_0. metadata ) ,
657
+ & bind_collector_1. binds ,
658
+ )
659
+ . enumerate ( )
660
+ . flat_map ( |( bind_index, ( ( bytes_0, metadata_0) , bytes_1) ) | {
661
+ // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
662
+ // in both cases the relevant buffer is a custom type on it's own
663
+ // so we only need to check the cases that contain a fake OID on their own
664
+ let ( bytes_0, bytes_1) = if matches ! ( metadata_0. oid( ) , Ok ( FAKE_OID ) ) {
665
+ (
666
+ bytes_0. as_deref ( ) . unwrap_or_default ( ) ,
667
+ bytes_1. as_deref ( ) . unwrap_or_default ( ) ,
668
+ )
669
+ } else {
670
+ // for all other cases, just return an empty
671
+ // list to make the iteration below a no-op
672
+ // and prevent the need of boxing
673
+ ( & [ ] as & [ _ ] , & [ ] as & [ _ ] )
674
+ } ;
675
+ let lookup_map = metadata_lookup_1
676
+ . generated_oids
677
+ . as_ref ( )
678
+ . map ( |map| {
679
+ map. values ( )
680
+ . flat_map ( |( oid, array_oid) | [ * oid, * array_oid] )
681
+ . collect :: < HashSet < _ > > ( )
682
+ } )
683
+ . unwrap_or_default ( ) ;
684
+ std:: iter:: zip (
685
+ bytes_0. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
686
+ bytes_1. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
687
+ )
688
+ . enumerate ( )
689
+ . filter_map ( move |( byte_index, ( l, r) ) | {
690
+ // here we infer if some byte sequence is a fake oid
691
+ // We use the following conditions for that:
692
+ //
693
+ // * The first byte sequence matches the constant FAKE_OID
694
+ // * The second sequence does not match the constant FAKE_OID
695
+ // * The second sequence is contained in the set of generated oid,
696
+ // otherwise we get false positives around the boundary
697
+ // of a to be replaced byte sequence
698
+ let r_val = u32:: from_be_bytes ( r. try_into ( ) . expect ( "That's the right size" ) ) ;
699
+ ( l == FAKE_OID . to_be_bytes ( )
700
+ && r != FAKE_OID . to_be_bytes ( )
701
+ && lookup_map. contains ( & r_val) )
702
+ . then_some ( ( bind_index, byte_index) )
703
+ } )
704
+ } )
705
+ // Avoid storing the bind collectors in the returned Future
706
+ . collect :: < Vec < _ > > ( ) ;
707
+ BindData {
708
+ collect_bind_result : collect_bind_result_0. and ( collect_bind_result_1) ,
709
+ fake_oid_locations,
710
+ generated_oids : metadata_lookup_1. generated_oids ,
711
+ bind_collector : bind_collector_1,
712
+ }
713
+ } else {
714
+ BindData {
715
+ collect_bind_result : collect_bind_result_0,
716
+ fake_oid_locations : Vec :: new ( ) ,
717
+ generated_oids : None ,
718
+ bind_collector : bind_collector_0,
719
+ }
720
+ }
721
+ }
722
+
713
723
type GeneratedOidTypeMap = Option < HashMap < ( Option < String > , String ) , ( u32 , u32 ) > > ;
714
724
715
725
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
0 commit comments