Commit 1d0372b

More optimizations
* Do not generate a second bind collector if we don't encounter a custom oid at all
* Do not generate a third bind collector at all, we don't need that
* Skip comparing buffers for types without custom oids as they won't contain any difference
* Minor cleanup + documentation of the approach
1 parent ab42065 commit 1d0372b
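The optimizations in this commit refine an existing two-pass trick: collect the binds once while handing out a constant fake OID for every custom type, collect them again with freshly generated OIDs, and compare the two serialized buffers to locate the byte offsets where an OID was embedded. Below is a minimal, self-contained sketch of that comparison idea; the names and the serialize helper are invented for illustration and are not the diesel-async API.

use std::collections::HashSet;

const FAKE_OID: u32 = 0;

// Stand-in for one bind collector pass: every "custom type" value (marked `None`
// here) is serialized as whatever oid the generator hands out; plain values as-is.
fn serialize(values: &[Option<u32>], mut oid_generator: impl FnMut() -> u32) -> Vec<u8> {
    values
        .iter()
        .copied()
        .flat_map(|v| v.unwrap_or_else(&mut oid_generator).to_be_bytes())
        .collect()
}

// Compare two passes window by window to find the byte offsets holding a fake oid.
fn fake_oid_locations(pass_0: &[u8], pass_1: &[u8], generated: &HashSet<u32>) -> Vec<usize> {
    std::iter::zip(
        pass_0.windows(std::mem::size_of_val(&FAKE_OID)),
        pass_1.windows(std::mem::size_of_val(&FAKE_OID)),
    )
    .enumerate()
    .filter_map(|(byte_index, (l, r))| {
        let r_val = u32::from_be_bytes(r.try_into().expect("window is 4 bytes"));
        // an offset holds a fake oid if pass 0 wrote FAKE_OID there, pass 1 wrote
        // something different, and that something is an oid pass 1 actually generated
        // (the last check avoids false positives at the boundaries of a match)
        (l == FAKE_OID.to_be_bytes() && r != FAKE_OID.to_be_bytes() && generated.contains(&r_val))
            .then_some(byte_index)
    })
    .collect()
}

fn main() {
    // two "custom type" binds (None) in between plain ones
    let values = [Some(42), None, Some(7), None];
    let pass_0 = serialize(&values, || FAKE_OID);

    let mut next = 90_000;
    let mut generated = HashSet::new();
    let pass_1 = serialize(&values, || {
        next += 1;
        generated.insert(next);
        next
    });

    // prints [4, 12]: the offsets that later get patched with the real oids
    println!("{:?}", fake_oid_locations(&pass_0, &pass_1, &generated));
}

With this commit, the second pass and the buffer comparison only happen when the first pass actually encountered a custom type, and buffers whose metadata does not carry the fake OID are skipped entirely.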

File tree

2 files changed: +199 -113 lines changed


src/pg/mod.rs

Lines changed: 173 additions & 109 deletions
@@ -22,7 +22,7 @@ use futures_util::future::Either;
 use futures_util::stream::{BoxStream, TryStreamExt};
 use futures_util::TryFutureExt;
 use futures_util::{Future, FutureExt, StreamExt};
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 use tokio::sync::broadcast;
 use tokio::sync::oneshot;
@@ -38,6 +38,8 @@ mod row;
 mod serialize;
 mod transaction_builder;
 
+const FAKE_OID: u32 = 0;
+
 /// A connection to a PostgreSQL database.
 ///
 /// Connection URLs should be in the form
@@ -257,7 +259,7 @@ fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
     }
 
     Ok(Type::new(
-        "diesel_custom_type".into(),
+        format!("diesel_custom_type_{oid}"),
        oid,
        tokio_postgres::types::Kind::Simple,
        "public".into(),
@@ -357,43 +359,134 @@ impl AsyncPgConnection {
             .to_sql(&mut query_builder, &Pg)
             .map(|_| query_builder.finish());
 
-        let mut bind_collector = RawBytesBindCollector::<diesel::pg::Pg>::new();
         let query_id = T::query_id();
 
-        // we don't resolve custom types here yet, we do that later
-        // in the async block below as we might need to perform lookup
-        // queries for that.
-        //
-        // We apply this workaround to prevent requiring all the diesel
-        // serialization code to beeing async
-        let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
-        let collect_bind_result_0 = query.collect_binds(
-            &mut bind_collector_0,
-            &mut SameOidEveryTime { first_byte: 0 },
-            &Pg,
-        );
-
-        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
-        let collect_bind_result_1 = query.collect_binds(
-            &mut bind_collector_1,
-            &mut SameOidEveryTime { first_byte: 1 },
-            &Pg,
-        );
-
-        let mut metadata_lookup = PgAsyncMetadataLookup::new(&bind_collector_0);
-        let collect_bind_result =
-            query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg);
-
-        let fake_oid_locations = std::iter::zip(bind_collector_0.binds, bind_collector_1.binds)
-            .enumerate()
-            .flat_map(|(bind_index, (bytes_0, bytes_1))| {
-                std::iter::zip(bytes_0.unwrap_or_default(), bytes_1.unwrap_or_default())
+        let (collect_bind_result, fake_oid_locations, generated_oids, mut bind_collector) = {
+            // we don't resolve custom types here yet, we do that later
+            // in the async block below as we might need to perform lookup
+            // queries for that.
+            //
+            // We apply this workaround to prevent requiring all the diesel
+            // serialization code to beeing async
+            //
+            // We give out constant fake oids here to optimize for the "happy" path
+            // without custom type lookup
+            let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
+            let mut metadata_lookup_0 = PgAsyncMetadataLookup {
+                custom_oid: false,
+                generated_oids: None,
+                oid_generator: |_, _| (FAKE_OID, FAKE_OID),
+            };
+            let collect_bind_result_0 =
+                query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
+
+            // we have encountered a custom type oid, so we need to perform more work here.
+            // These oids can occure in two locations:
+            //
+            // * In the collected metadata -> relativly easy to resolve, just need to replace them below
+            // * As part of the seralized bind blob -> hard to replace
+            //
+            // To address the second case, we perform a second run of the bind collector
+            // with a different set of fake oids. Then we compare the output of the two runs
+            // and use that information to infer where to replace bytes in the serialized output
+
+            if metadata_lookup_0.custom_oid {
+                // we try to get the maxium oid we encountered here
+                // to be sure that we don't accidently give out a fake oid below that collides with
+                // something
+                let mut max_oid = bind_collector_0
+                    .metadata
+                    .iter()
+                    .flat_map(|t| {
+                        [
+                            t.oid().unwrap_or_default(),
+                            t.array_oid().unwrap_or_default(),
+                        ]
+                    })
+                    .max()
+                    .unwrap_or_default();
+                let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
+                let mut metadata_lookup_1 = PgAsyncMetadataLookup {
+                    custom_oid: false,
+                    generated_oids: Some(HashMap::new()),
+                    oid_generator: move |_, _| {
+                        max_oid += 2;
+                        (max_oid, max_oid + 1)
+                    },
+                };
+                let collect_bind_result_2 =
+                    query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
+
+                assert_eq!(
+                    bind_collector_0.binds.len(),
+                    bind_collector_0.metadata.len()
+                );
+                let fake_oid_locations = std::iter::zip(
+                    bind_collector_0
+                        .binds
+                        .iter()
+                        .zip(&bind_collector_0.metadata),
+                    &bind_collector_1.binds,
+                )
+                .enumerate()
+                .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
+                    // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
+                    // in both cases the relevant buffer is a custom type on it's own
+                    // so we only need to check the cases that contain a fake OID on their own
+                    let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
+                        (
+                            bytes_0.as_deref().unwrap_or_default(),
+                            bytes_1.as_deref().unwrap_or_default(),
+                        )
+                    } else {
+                        // for all other cases, just return an empty
+                        // list to make the iteration below a no-op
+                        // and prevent the need of boxing
+                        (&[] as &[_], &[] as &[_])
+                    };
+                    let lookup_map = metadata_lookup_1
+                        .generated_oids
+                        .as_ref()
+                        .map(|map| {
+                            map.values()
+                                .flat_map(|(oid, array_oid)| [*oid, *array_oid])
+                                .collect::<HashSet<_>>()
+                        })
+                        .unwrap_or_default();
+                    std::iter::zip(
+                        bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
+                        bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
+                    )
                     .enumerate()
-                    .filter(|&(_, bytes)| bytes == (0, 1))
-                    .map(move |(byte_index, _)| (bind_index, byte_index))
-            })
-            // Avoid storing the bind collectors in the returned Future
-            .collect::<Vec<_>>();
+                    .filter_map(move |(byte_index, (l, r))| {
+                        // here we infer if some byte sequence is a fake oid
+                        // We use the following conditions for that:
+                        //
+                        // * The first byte sequence matches the constant FAKE_OID
+                        // * The second sequence does not match the constant FAKE_OID
+                        // * The second sequence is contained in the set of generated oid,
+                        //   otherwise we get false positives around the boundary
+                        //   of a to be replaced byte sequence
+                        let r_val =
                            u32::from_be_bytes(r.try_into().expect("That's the right size"));
+                        (l == FAKE_OID.to_be_bytes()
+                            && r != FAKE_OID.to_be_bytes()
+                            && lookup_map.contains(&r_val))
+                        .then_some((bind_index, byte_index))
+                    })
+                })
+                // Avoid storing the bind collectors in the returned Future
+                .collect::<Vec<_>>();
+                (
+                    collect_bind_result_0.and(collect_bind_result_2),
+                    fake_oid_locations,
+                    metadata_lookup_1.generated_oids,
+                    bind_collector_1,
+                )
+            } else {
+                (collect_bind_result_0, Vec::new(), None, bind_collector_0)
+            }
+        };
 
         let raw_connection = self.conn.clone();
         let stmt_cache = self.stmt_cache.clone();
@@ -403,59 +496,49 @@ impl AsyncPgConnection {
         async move {
             let sql = sql?;
             let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
-            collect_bind_result_0?;
-            collect_bind_result_1?;
             collect_bind_result?;
             // Check whether we need to resolve some types at all
             //
             // If the user doesn't use custom types there is no need
             // to borther with that at all
-            if !metadata_lookup.unresolved_types.is_empty() {
+            if let Some(ref unresolved_types) = generated_oids {
                 let metadata_cache = &mut *metadata_cache.lock().await;
-                let mut real_oids = HashMap::<u32, u32>::new();
+                let mut real_oids = HashMap::new();
 
-                for (index, (schema, lookup_type_name)) in metadata_lookup.unresolved_types.iter().enumerate() {
+                for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
+                    unresolved_types
+                {
                     // for each unresolved item
                     // we check whether it's arleady in the cache
                     // or perform a lookup and insert it into the cache
                     let cache_key = PgMetadataCacheKey::new(
-                        schema.as_ref().map(Into::into),
+                        schema.as_deref().map(Into::into),
                         lookup_type_name.into(),
                     );
-                    let real_metadata = if let Some(type_metadata) = metadata_cache.lookup_type(&cache_key) {
+                    let real_metadata = if let Some(type_metadata) =
+                        metadata_cache.lookup_type(&cache_key)
+                    {
                         type_metadata
                     } else {
-                        let type_metadata = lookup_type(
-                            schema.clone(),
-                            lookup_type_name.clone(),
-                            &raw_connection,
-                        )
-                        .await?;
+                        let type_metadata =
                            lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
                                .await?;
                         metadata_cache.store_type(cache_key, type_metadata);
 
                         PgTypeMetadata::from_result(Ok(type_metadata))
                     };
-                    let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
-                    let [real_oid, real_array_oid] = unwrap_oids(&real_metadata);
-                    real_oids.extend([
-                        (fake_oid, real_oid),
-                        (fake_array_oid, real_array_oid),
-                    ]);
+                    // let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
+                    let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
+                    real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
                 }
 
                 // Replace fake OIDs with real OIDs in `bind_collector.metadata`
                 for m in &mut bind_collector.metadata {
-                    let [oid, array_oid] = unwrap_oids(&m)
-                        .map(|oid| {
-                            real_oids
-                                .get(&oid)
-                                .copied()
-                                // If `oid` is not a key in `real_oids`, then `HasSqlType::metadata` returned it as a
-                                // hardcoded value instead of being lied to by `PgAsyncMetadataLookup`. In this case,
-                                // the existing value is already the real OID, so it's kept.
-                                .unwrap_or(oid)
-                        });
-                    *m = PgTypeMetadata::new(oid, array_oid);
+                    let (oid, array_oid) = unwrap_oids(m);
+                    *m = PgTypeMetadata::new(
+                        real_oids.get(&oid).copied().unwrap_or(oid),
+                        real_oids.get(&array_oid).copied().unwrap_or(array_oid)
+                    );
                 }
                 // Replace fake OIDs with real OIDs in `bind_collector.binds`
                 for (bind_index, byte_index) in fake_oid_locations {
@@ -503,53 +586,31 @@ impl AsyncPgConnection {
     }
 }
 
+type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
+
 /// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
 /// so they can be replaced with asynchronously fetched OIDs after the original query is dropped
-struct PgAsyncMetadataLookup {
-    unresolved_types: Vec<(Option<String>, String)>,
-    min_fake_oid: u32,
+struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
+    custom_oid: bool,
+    generated_oids: GeneratedOidTypeMap,
+    oid_generator: F,
 }
 
-impl PgAsyncMetadataLookup {
-    fn new(bind_collector_0: &RawBytesBindCollector<Pg>) -> Self {
-        let max_hardcoded_oid = bind_collector_0
-            .metadata
-            .iter()
-            .flat_map(|m| [m.oid().unwrap_or(0), m.array_oid().unwrap_or(0)])
-            .max()
-            .unwrap_or(0);
-        Self {
-            unresolved_types: Vec::new(),
-            min_fake_oid: max_hardcoded_oid + 1,
-        }
-    }
-
-    fn fake_oids(&self, index: usize) -> (u32, u32) {
-        let oid = self.min_fake_oid + ((index as u32) * 2);
-        (oid, oid + 1)
-    }
-}
-
-impl PgMetadataLookup for PgAsyncMetadataLookup {
+impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
+where
+    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
+{
     fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
-        let index = self.unresolved_types.len();
-        self.unresolved_types
-            .push((schema.map(ToOwned::to_owned), type_name.to_owned()));
-        PgTypeMetadata::from_result(Ok(self.fake_oids(index)))
-    }
-}
+        self.custom_oid = true;
 
-/// Allows unambiguously determining:
-/// * where OIDs are written in `bind_collector.binds` after being returned by `lookup_type`
-/// * determining the maximum hardcoded OID in `bind_collector.metadata`
-struct SameOidEveryTime {
-    first_byte: u8,
-}
+        let oid = if let Some(map) = &mut self.generated_oids {
+            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
+                .or_insert_with(|| (self.oid_generator)(type_name, schema))
+        } else {
+            (self.oid_generator)(type_name, schema)
+        };
 
-impl PgMetadataLookup for SameOidEveryTime {
-    fn lookup_type(&mut self, _type_name: &str, _schema: Option<&str>) -> PgTypeMetadata {
-        let oid = u32::from_be_bytes([self.first_byte, 0, 0, 0]);
-        PgTypeMetadata::new(oid, oid)
+        PgTypeMetadata::from_result(Ok(oid))
     }
 }
 
@@ -583,9 +644,12 @@ async fn lookup_type(
     Ok((r.get(0), r.get(1)))
 }
 
-fn unwrap_oids(metadata: &PgTypeMetadata) -> [u32; 2] {
-    [metadata.oid().ok(), metadata.array_oid().ok()]
-        .map(|oid| oid.expect("PgTypeMetadata is supposed to always be Ok here"))
+fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
+    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
+    (
+        metadata.oid().expect(err_msg),
+        metadata.array_oid().expect(err_msg),
+    )
 }
 
 fn replace_fake_oid(
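The last hunk above is cut off at replace_fake_oid. As a rough illustration of that final patch step (a hypothetical helper with made-up names and values, not the commit's actual replace_fake_oid), each recorded (bind_index, byte_index) location is overwritten in place once the real OIDs have been fetched:

use std::collections::HashMap;

// Hypothetical helper: overwrite every recorded (bind_index, byte_index) location
// with the big-endian bytes of the real oid looked up for the fake oid found there.
fn patch_fake_oids(
    binds: &mut [Option<Vec<u8>>],
    fake_oid_locations: &[(usize, usize)],
    real_oids: &HashMap<u32, u32>,
) {
    for &(bind_index, byte_index) in fake_oid_locations {
        if let Some(Some(bytes)) = binds.get_mut(bind_index) {
            if let Some(window) = bytes.get_mut(byte_index..byte_index + 4) {
                // the buffer currently holds the generated fake oid at this offset
                let fake = u32::from_be_bytes([window[0], window[1], window[2], window[3]]);
                if let Some(real) = real_oids.get(&fake) {
                    window.copy_from_slice(&real.to_be_bytes());
                }
            }
        }
    }
}

fn main() {
    // one bind whose payload embeds the fake oid 90_001 at byte offset 4
    let mut binds = vec![Some(vec![0, 0, 0, 42, 0, 1, 95, 145])];
    let real_oids = HashMap::from([(90_001, 16_384)]);
    patch_fake_oids(&mut binds, &[(0, 4)], &real_oids);
    assert_eq!(&binds[0].as_deref().unwrap()[4..], 16_384u32.to_be_bytes());
}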
