@@ -22,7 +22,7 @@ use futures_util::future::Either;
22
22
use futures_util:: stream:: { BoxStream , TryStreamExt } ;
23
23
use futures_util:: TryFutureExt ;
24
24
use futures_util:: { Future , FutureExt , StreamExt } ;
25
- use std:: collections:: HashMap ;
25
+ use std:: collections:: { HashMap , HashSet } ;
26
26
use std:: sync:: Arc ;
27
27
use tokio:: sync:: broadcast;
28
28
use tokio:: sync:: oneshot;
@@ -38,6 +38,8 @@ mod row;
38
38
mod serialize;
39
39
mod transaction_builder;
40
40
41
/// Placeholder OID handed out for custom types during the first bind
/// collection pass. Occurrences of it in the collected metadata and in the
/// serialized binds are replaced with the real OIDs once those are known.
+ const FAKE_OID : u32 = 0 ;
42
+
41
43
/// A connection to a PostgreSQL database.
42
44
///
43
45
/// Connection URLs should be in the form
@@ -257,7 +259,7 @@ fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
257
259
}
258
260
259
261
Ok ( Type :: new (
260
- "diesel_custom_type" . into ( ) ,
262
+ format ! ( "diesel_custom_type_{oid}" ) ,
261
263
oid,
262
264
tokio_postgres:: types:: Kind :: Simple ,
263
265
"public" . into ( ) ,
@@ -357,43 +359,134 @@ impl AsyncPgConnection {
357
359
. to_sql ( & mut query_builder, & Pg )
358
360
. map ( |_| query_builder. finish ( ) ) ;
359
361
360
- let mut bind_collector = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
361
362
let query_id = T :: query_id ( ) ;
362
363
363
- // we don't resolve custom types here yet, we do that later
364
- // in the async block below as we might need to perform lookup
365
- // queries for that.
366
- //
367
- // We apply this workaround to prevent requiring all the diesel
368
- // serialization code to beeing async
369
- let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
370
- let collect_bind_result_0 = query. collect_binds (
371
- & mut bind_collector_0,
372
- & mut SameOidEveryTime { first_byte : 0 } ,
373
- & Pg ,
374
- ) ;
375
-
376
- let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
377
- let collect_bind_result_1 = query. collect_binds (
378
- & mut bind_collector_1,
379
- & mut SameOidEveryTime { first_byte : 1 } ,
380
- & Pg ,
381
- ) ;
382
-
383
- let mut metadata_lookup = PgAsyncMetadataLookup :: new ( & bind_collector_0) ;
384
- let collect_bind_result =
385
- query. collect_binds ( & mut bind_collector, & mut metadata_lookup, & Pg ) ;
386
-
387
- let fake_oid_locations = std:: iter:: zip ( bind_collector_0. binds , bind_collector_1. binds )
388
- . enumerate ( )
389
- . flat_map ( |( bind_index, ( bytes_0, bytes_1) ) | {
390
- std:: iter:: zip ( bytes_0. unwrap_or_default ( ) , bytes_1. unwrap_or_default ( ) )
364
+ let ( collect_bind_result, fake_oid_locations, generated_oids, mut bind_collector) = {
365
+ // we don't resolve custom types here yet, we do that later
366
+ // in the async block below as we might need to perform lookup
367
+ // queries for that.
368
+ //
369
+ // We apply this workaround to prevent requiring all the diesel
370
+ // serialization code to be async
371
+ //
372
+ // We give out constant fake oids here to optimize for the "happy" path
373
+ // without custom type lookup
374
+ let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
375
+ let mut metadata_lookup_0 = PgAsyncMetadataLookup {
376
+ custom_oid : false ,
377
+ generated_oids : None ,
378
+ oid_generator : |_, _| ( FAKE_OID , FAKE_OID ) ,
379
+ } ;
380
+ let collect_bind_result_0 =
381
+ query. collect_binds ( & mut bind_collector_0, & mut metadata_lookup_0, & Pg ) ;
382
+
383
+ // we have encountered a custom type oid, so we need to perform more work here.
384
+ // These oids can occur in two locations:
385
+ //
386
+ // * In the collected metadata -> relatively easy to resolve, just need to replace them below
387
+ // * As part of the serialized bind blob -> hard to replace
388
+ //
389
+ // To address the second case, we perform a second run of the bind collector
390
+ // with a different set of fake oids. Then we compare the output of the two runs
391
+ // and use that information to infer where to replace bytes in the serialized output
392
+
393
+ if metadata_lookup_0. custom_oid {
394
+ // we try to get the maximum oid we encountered here
395
+ // to be sure that we don't accidentally give out a fake oid below that collides with
396
+ // something
397
+ let mut max_oid = bind_collector_0
398
+ . metadata
399
+ . iter ( )
400
+ . flat_map ( |t| {
401
+ [
402
+ t. oid ( ) . unwrap_or_default ( ) ,
403
+ t. array_oid ( ) . unwrap_or_default ( ) ,
404
+ ]
405
+ } )
406
+ . max ( )
407
+ . unwrap_or_default ( ) ;
408
+ let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
409
+ let mut metadata_lookup_1 = PgAsyncMetadataLookup {
410
+ custom_oid : false ,
411
+ generated_oids : Some ( HashMap :: new ( ) ) ,
412
+ oid_generator : move |_, _| {
413
+ max_oid += 2 ;
414
+ ( max_oid, max_oid + 1 )
415
+ } ,
416
+ } ;
417
+ let collect_bind_result_2 =
418
+ query. collect_binds ( & mut bind_collector_1, & mut metadata_lookup_1, & Pg ) ;
419
+
420
+ assert_eq ! (
421
+ bind_collector_0. binds. len( ) ,
422
+ bind_collector_0. metadata. len( )
423
+ ) ;
424
+ let fake_oid_locations = std:: iter:: zip (
425
+ bind_collector_0
426
+ . binds
427
+ . iter ( )
428
+ . zip ( & bind_collector_0. metadata ) ,
429
+ & bind_collector_1. binds ,
430
+ )
431
+ . enumerate ( )
432
+ . flat_map ( |( bind_index, ( ( bytes_0, metadata_0) , bytes_1) ) | {
433
+ // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
434
+ // in both cases the relevant buffer is a custom type on its own
435
+ // so we only need to check the cases that contain a fake OID on their own
436
+ let ( bytes_0, bytes_1) = if matches ! ( metadata_0. oid( ) , Ok ( FAKE_OID ) ) {
437
+ (
438
+ bytes_0. as_deref ( ) . unwrap_or_default ( ) ,
439
+ bytes_1. as_deref ( ) . unwrap_or_default ( ) ,
440
+ )
441
+ } else {
442
+ // for all other cases, just return an empty
443
+ // list to make the iteration below a no-op
444
+ // and prevent the need of boxing
445
+ ( & [ ] as & [ _ ] , & [ ] as & [ _ ] )
446
+ } ;
447
+ let lookup_map = metadata_lookup_1
448
+ . generated_oids
449
+ . as_ref ( )
450
+ . map ( |map| {
451
+ map. values ( )
452
+ . flat_map ( |( oid, array_oid) | [ * oid, * array_oid] )
453
+ . collect :: < HashSet < _ > > ( )
454
+ } )
455
+ . unwrap_or_default ( ) ;
456
+ std:: iter:: zip (
457
+ bytes_0. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
458
+ bytes_1. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
459
+ )
391
460
. enumerate ( )
392
- . filter ( |& ( _, bytes) | bytes == ( 0 , 1 ) )
393
- . map ( move |( byte_index, _) | ( bind_index, byte_index) )
394
- } )
395
- // Avoid storing the bind collectors in the returned Future
396
- . collect :: < Vec < _ > > ( ) ;
461
+ . filter_map ( move |( byte_index, ( l, r) ) | {
462
+ // here we infer if some byte sequence is a fake oid
463
+ // We use the following conditions for that:
464
+ //
465
+ // * The first byte sequence matches the constant FAKE_OID
466
+ // * The second sequence does not match the constant FAKE_OID
467
+ // * The second sequence is contained in the set of generated oid,
468
+ // otherwise we get false positives around the boundary
469
+ // of a to be replaced byte sequence
470
+ let r_val =
471
+ u32:: from_be_bytes ( r. try_into ( ) . expect ( "That's the right size" ) ) ;
472
+ ( l == FAKE_OID . to_be_bytes ( )
473
+ && r != FAKE_OID . to_be_bytes ( )
474
+ && lookup_map. contains ( & r_val) )
475
+ . then_some ( ( bind_index, byte_index) )
476
+ } )
477
+ } )
478
+ // Avoid storing the bind collectors in the returned Future
479
+ . collect :: < Vec < _ > > ( ) ;
480
+ (
481
+ collect_bind_result_0. and ( collect_bind_result_2) ,
482
+ fake_oid_locations,
483
+ metadata_lookup_1. generated_oids ,
484
+ bind_collector_1,
485
+ )
486
+ } else {
487
+ ( collect_bind_result_0, Vec :: new ( ) , None , bind_collector_0)
488
+ }
489
+ } ;
397
490
398
491
let raw_connection = self . conn . clone ( ) ;
399
492
let stmt_cache = self . stmt_cache . clone ( ) ;
@@ -403,59 +496,49 @@ impl AsyncPgConnection {
403
496
async move {
404
497
let sql = sql?;
405
498
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
406
- collect_bind_result_0?;
407
- collect_bind_result_1?;
408
499
collect_bind_result?;
409
500
// Check whether we need to resolve some types at all
410
501
//
411
502
// If the user doesn't use custom types there is no need
412
503
// to bother with that at all
413
- if !metadata_lookup . unresolved_types . is_empty ( ) {
504
+ if let Some ( ref unresolved_types ) = generated_oids {
414
505
let metadata_cache = & mut * metadata_cache. lock ( ) . await ;
415
- let mut real_oids = HashMap :: < u32 , u32 > :: new ( ) ;
506
+ let mut real_oids = HashMap :: new ( ) ;
416
507
417
- for ( index, ( schema, lookup_type_name) ) in metadata_lookup. unresolved_types . iter ( ) . enumerate ( ) {
508
+ for ( ( schema, lookup_type_name) , ( fake_oid, fake_array_oid) ) in
509
+ unresolved_types
510
+ {
418
511
// for each unresolved item
419
512
// we check whether it's already in the cache
420
513
// or perform a lookup and insert it into the cache
421
514
let cache_key = PgMetadataCacheKey :: new (
422
- schema. as_ref ( ) . map ( Into :: into) ,
515
+ schema. as_deref ( ) . map ( Into :: into) ,
423
516
lookup_type_name. into ( ) ,
424
517
) ;
425
- let real_metadata = if let Some ( type_metadata) = metadata_cache. lookup_type ( & cache_key) {
518
+ let real_metadata = if let Some ( type_metadata) =
519
+ metadata_cache. lookup_type ( & cache_key)
520
+ {
426
521
type_metadata
427
522
} else {
428
- let type_metadata = lookup_type (
429
- schema. clone ( ) ,
430
- lookup_type_name. clone ( ) ,
431
- & raw_connection,
432
- )
433
- . await ?;
523
+ let type_metadata =
524
+ lookup_type ( schema. clone ( ) , lookup_type_name. clone ( ) , & raw_connection)
525
+ . await ?;
434
526
metadata_cache. store_type ( cache_key, type_metadata) ;
435
527
436
528
PgTypeMetadata :: from_result ( Ok ( type_metadata) )
437
529
} ;
438
- let ( fake_oid, fake_array_oid) = metadata_lookup. fake_oids ( index) ;
439
- let [ real_oid, real_array_oid] = unwrap_oids ( & real_metadata) ;
440
- real_oids. extend ( [
441
- ( fake_oid, real_oid) ,
442
- ( fake_array_oid, real_array_oid) ,
443
- ] ) ;
530
+ // let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
531
+ let ( real_oid, real_array_oid) = unwrap_oids ( & real_metadata) ;
532
+ real_oids. extend ( [ ( * fake_oid, real_oid) , ( * fake_array_oid, real_array_oid) ] ) ;
444
533
}
445
534
446
535
// Replace fake OIDs with real OIDs in `bind_collector.metadata`
447
536
for m in & mut bind_collector. metadata {
448
- let [ oid, array_oid] = unwrap_oids ( & m)
449
- . map ( |oid| {
450
- real_oids
451
- . get ( & oid)
452
- . copied ( )
453
- // If `oid` is not a key in `real_oids`, then `HasSqlType::metadata` returned it as a
454
- // hardcoded value instead of being lied to by `PgAsyncMetadataLookup`. In this case,
455
- // the existing value is already the real OID, so it's kept.
456
- . unwrap_or ( oid)
457
- } ) ;
458
- * m = PgTypeMetadata :: new ( oid, array_oid) ;
537
+ let ( oid, array_oid) = unwrap_oids ( m) ;
538
+ * m = PgTypeMetadata :: new (
539
+ real_oids. get ( & oid) . copied ( ) . unwrap_or ( oid) ,
540
+ real_oids. get ( & array_oid) . copied ( ) . unwrap_or ( array_oid)
541
+ ) ;
459
542
}
460
543
// Replace fake OIDs with real OIDs in `bind_collector.binds`
461
544
for ( bind_index, byte_index) in fake_oid_locations {
@@ -503,53 +586,31 @@ impl AsyncPgConnection {
503
586
}
504
587
}
505
588
589
/// OID pairs generated so far, keyed by `(schema, type_name)`.
///
/// Every value is the `(oid, array_oid)` pair that was produced for that
/// type. `None` means that no bookkeeping of generated OIDs is performed.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;

/// Collects types that need to be looked up, and causes fake OIDs to be
/// written into the bind collector so they can be replaced with
/// asynchronously fetched OIDs after the original query is dropped.
struct PgAsyncMetadataLookup<F>
where
    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
{
    /// Becomes `true` as soon as `lookup_type` has been called at least once.
    custom_oid: bool,
    /// When `Some`, remembers the OID pair generated for each requested
    /// `(schema, type_name)` so repeated requests receive the same pair.
    generated_oids: GeneratedOidTypeMap,
    /// Produces the `(oid, array_oid)` pair for a `(type_name, schema)` request.
    oid_generator: F,
}
512
598
513
- impl PgAsyncMetadataLookup {
514
- fn new ( bind_collector_0 : & RawBytesBindCollector < Pg > ) -> Self {
515
- let max_hardcoded_oid = bind_collector_0
516
- . metadata
517
- . iter ( )
518
- . flat_map ( |m| [ m. oid ( ) . unwrap_or ( 0 ) , m. array_oid ( ) . unwrap_or ( 0 ) ] )
519
- . max ( )
520
- . unwrap_or ( 0 ) ;
521
- Self {
522
- unresolved_types : Vec :: new ( ) ,
523
- min_fake_oid : max_hardcoded_oid + 1 ,
524
- }
525
- }
526
-
527
- fn fake_oids ( & self , index : usize ) -> ( u32 , u32 ) {
528
- let oid = self . min_fake_oid + ( ( index as u32 ) * 2 ) ;
529
- ( oid, oid + 1 )
530
- }
531
- }
532
-
533
- impl PgMetadataLookup for PgAsyncMetadataLookup {
599
+ impl < F > PgMetadataLookup for PgAsyncMetadataLookup < F >
600
+ where
601
+ F : FnMut ( & str , Option < & str > ) -> ( u32 , u32 ) + ' static ,
602
+ {
534
603
fn lookup_type ( & mut self , type_name : & str , schema : Option < & str > ) -> PgTypeMetadata {
535
- let index = self . unresolved_types . len ( ) ;
536
- self . unresolved_types
537
- . push ( ( schema. map ( ToOwned :: to_owned) , type_name. to_owned ( ) ) ) ;
538
- PgTypeMetadata :: from_result ( Ok ( self . fake_oids ( index) ) )
539
- }
540
- }
604
+ self . custom_oid = true ;
541
605
542
- /// Allows unambiguously determining:
543
- /// * where OIDs are written in `bind_collector.binds` after being returned by `lookup_type`
544
- /// * determining the maximum hardcoded OID in `bind_collector.metadata`
545
- struct SameOidEveryTime {
546
- first_byte : u8 ,
547
- }
606
+ let oid = if let Some ( map ) = & mut self . generated_oids {
607
+ * map . entry ( ( schema . map ( ToOwned :: to_owned ) , type_name . to_owned ( ) ) )
608
+ . or_insert_with ( || ( self . oid_generator ) ( type_name , schema ) )
609
+ } else {
610
+ ( self . oid_generator ) ( type_name , schema )
611
+ } ;
548
612
549
- impl PgMetadataLookup for SameOidEveryTime {
550
- fn lookup_type ( & mut self , _type_name : & str , _schema : Option < & str > ) -> PgTypeMetadata {
551
- let oid = u32:: from_be_bytes ( [ self . first_byte , 0 , 0 , 0 ] ) ;
552
- PgTypeMetadata :: new ( oid, oid)
613
+ PgTypeMetadata :: from_result ( Ok ( oid) )
553
614
}
554
615
}
555
616
@@ -583,9 +644,12 @@ async fn lookup_type(
583
644
Ok ( ( r. get ( 0 ) , r. get ( 1 ) ) )
584
645
}
585
646
586
- fn unwrap_oids ( metadata : & PgTypeMetadata ) -> [ u32 ; 2 ] {
587
- [ metadata. oid ( ) . ok ( ) , metadata. array_oid ( ) . ok ( ) ]
588
- . map ( |oid| oid. expect ( "PgTypeMetadata is supposed to always be Ok here" ) )
647
+ fn unwrap_oids ( metadata : & PgTypeMetadata ) -> ( u32 , u32 ) {
648
+ let err_msg = "PgTypeMetadata is supposed to always be Ok here" ;
649
+ (
650
+ metadata. oid ( ) . expect ( err_msg) ,
651
+ metadata. array_oid ( ) . expect ( err_msg) ,
652
+ )
589
653
}
590
654
591
655
fn replace_fake_oid (
0 commit comments