Skip to content

Commit ff0ccf3

Browse files
committed
Some more cleanup
1 parent 3d8b5a5 commit ff0ccf3

File tree

1 file changed

+145
-135
lines changed

1 file changed

+145
-135
lines changed

src/pg/mod.rs

Lines changed: 145 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -433,159 +433,27 @@ impl AsyncPgConnection {
433433
// so there is no need to even access the query in the async block below
434434
let mut query_builder = PgQueryBuilder::default();
435435

436-
let (collect_bind_result, fake_oid_locations, generated_oids, bind_collector) = {
437-
// we don't resolve custom types here yet, we do that later
438-
// in the async block below as we might need to perform lookup
439-
// queries for that.
440-
//
441-
// We apply this workaround to prevent requiring all the diesel
442-
// serialization code to be async
443-
//
444-
// We give out constant fake oids here to optimize for the "happy" path
445-
// without custom type lookup
446-
let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
447-
let mut metadata_lookup_0 = PgAsyncMetadataLookup {
448-
custom_oid: false,
449-
generated_oids: None,
450-
oid_generator: |_, _| (FAKE_OID, FAKE_OID),
451-
};
452-
let collect_bind_result_0 =
453-
query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
454-
455-
// we have encountered a custom type oid, so we need to perform more work here.
456-
// These oids can occur in two locations:
457-
//
458-
// * In the collected metadata -> relatively easy to resolve, just need to replace them below
459-
// * As part of the serialized bind blob -> hard to replace
460-
//
461-
// To address the second case, we perform a second run of the bind collector
462-
// with a different set of fake oids. Then we compare the output of the two runs
463-
// and use that information to infer where to replace bytes in the serialized output
464-
465-
if metadata_lookup_0.custom_oid {
466-
// we try to get the maximum oid we encountered here
467-
// to be sure that we don't accidentally give out a fake oid below that collides with
468-
// something
469-
let mut max_oid = bind_collector_0
470-
.metadata
471-
.iter()
472-
.flat_map(|t| {
473-
[
474-
t.oid().unwrap_or_default(),
475-
t.array_oid().unwrap_or_default(),
476-
]
477-
})
478-
.max()
479-
.unwrap_or_default();
480-
let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
481-
let mut metadata_lookup_1 = PgAsyncMetadataLookup {
482-
custom_oid: false,
483-
generated_oids: Some(HashMap::new()),
484-
oid_generator: move |_, _| {
485-
max_oid += 2;
486-
(max_oid, max_oid + 1)
487-
},
488-
};
489-
let collect_bind_result_2 =
490-
query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
491-
492-
assert_eq!(
493-
bind_collector_0.binds.len(),
494-
bind_collector_0.metadata.len()
495-
);
496-
let fake_oid_locations = std::iter::zip(
497-
bind_collector_0
498-
.binds
499-
.iter()
500-
.zip(&bind_collector_0.metadata),
501-
&bind_collector_1.binds,
502-
)
503-
.enumerate()
504-
.flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
505-
// custom oids might appear in the serialized bind arguments for arrays or composite (record) types
506-
// in both cases the relevant buffer is a custom type on its own
507-
// so we only need to check the cases that contain a fake OID on their own
508-
let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
509-
(
510-
bytes_0.as_deref().unwrap_or_default(),
511-
bytes_1.as_deref().unwrap_or_default(),
512-
)
513-
} else {
514-
// for all other cases, just return an empty
515-
// list to make the iteration below a no-op
516-
// and prevent the need of boxing
517-
(&[] as &[_], &[] as &[_])
518-
};
519-
let lookup_map = metadata_lookup_1
520-
.generated_oids
521-
.as_ref()
522-
.map(|map| {
523-
map.values()
524-
.flat_map(|(oid, array_oid)| [*oid, *array_oid])
525-
.collect::<HashSet<_>>()
526-
})
527-
.unwrap_or_default();
528-
std::iter::zip(
529-
bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
530-
bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
531-
)
532-
.enumerate()
533-
.filter_map(move |(byte_index, (l, r))| {
534-
// here we infer if some byte sequence is a fake oid
535-
// We use the following conditions for that:
536-
//
537-
// * The first byte sequence matches the constant FAKE_OID
538-
// * The second sequence does not match the constant FAKE_OID
539-
// * The second sequence is contained in the set of generated oids,
540-
// otherwise we get false positives around the boundary
541-
// of a to be replaced byte sequence
542-
let r_val =
543-
u32::from_be_bytes(r.try_into().expect("That's the right size"));
544-
(l == FAKE_OID.to_be_bytes()
545-
&& r != FAKE_OID.to_be_bytes()
546-
&& lookup_map.contains(&r_val))
547-
.then_some((bind_index, byte_index))
548-
})
549-
})
550-
// Avoid storing the bind collectors in the returned Future
551-
.collect::<Vec<_>>();
552-
(
553-
collect_bind_result_0.and(collect_bind_result_2),
554-
fake_oid_locations,
555-
metadata_lookup_1.generated_oids,
556-
bind_collector_1,
557-
)
558-
} else {
559-
(collect_bind_result_0, Vec::new(), None, bind_collector_0)
560-
}
561-
};
436+
let bind_data = construct_bind_data(&query);
562437

563438
// The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
564439
self.with_prepared_statement_after_sql_built(
565440
callback,
566441
query.is_safe_to_cache_prepared(&Pg),
567442
T::query_id(),
568443
query.to_sql(&mut query_builder, &Pg),
569-
collect_bind_result,
570444
query_builder,
571-
bind_collector,
572-
fake_oid_locations,
573-
generated_oids,
445+
bind_data,
574446
)
575447
}
576448

577-
#[allow(clippy::too_many_arguments)]
578449
fn with_prepared_statement_after_sql_built<'a, F, R>(
579450
&mut self,
580451
callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
581452
is_safe_to_cache_prepared: QueryResult<bool>,
582453
query_id: Option<std::any::TypeId>,
583454
to_sql_result: QueryResult<()>,
584-
collect_bind_result: QueryResult<()>,
585455
query_builder: PgQueryBuilder,
586-
mut bind_collector: RawBytesBindCollector<Pg>,
587-
fake_oid_locations: Vec<(usize, usize)>,
588-
generated_oids: GeneratedOidTypeMap,
456+
bind_data: BindData,
589457
) -> BoxFuture<'a, QueryResult<R>>
590458
where
591459
F: Future<Output = QueryResult<R>> + Send + 'a,
@@ -596,6 +464,12 @@ impl AsyncPgConnection {
596464
let metadata_cache = self.metadata_cache.clone();
597465
let tm = self.transaction_state.clone();
598466
let instrumentation = self.instrumentation.clone();
467+
let BindData {
468+
collect_bind_result,
469+
fake_oid_locations,
470+
generated_oids,
471+
mut bind_collector,
472+
} = bind_data;
599473

600474
async move {
601475
let sql = to_sql_result.map(|_| query_builder.finish())?;
@@ -710,6 +584,142 @@ impl AsyncPgConnection {
710584
}
711585
}
712586

587+
struct BindData {
588+
collect_bind_result: Result<(), Error>,
589+
fake_oid_locations: Vec<(usize, usize)>,
590+
generated_oids: GeneratedOidTypeMap,
591+
bind_collector: RawBytesBindCollector<Pg>,
592+
}
593+
594+
fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
595+
// we don't resolve custom types here yet, we do that later
596+
// in the async block below as we might need to perform lookup
597+
// queries for that.
598+
//
599+
// We apply this workaround to prevent requiring all the diesel
600+
// serialization code to be async
601+
//
602+
// We give out constant fake oids here to optimize for the "happy" path
603+
// without custom type lookup
604+
let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
605+
let mut metadata_lookup_0 = PgAsyncMetadataLookup {
606+
custom_oid: false,
607+
generated_oids: None,
608+
oid_generator: |_, _| (FAKE_OID, FAKE_OID),
609+
};
610+
let collect_bind_result_0 =
611+
query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
612+
// we have encountered a custom type oid, so we need to perform more work here.
613+
// These oids can occur in two locations:
614+
//
615+
// * In the collected metadata -> relatively easy to resolve, just need to replace them below
616+
// * As part of the serialized bind blob -> hard to replace
617+
//
618+
// To address the second case, we perform a second run of the bind collector
619+
// with a different set of fake oids. Then we compare the output of the two runs
620+
// and use that information to infer where to replace bytes in the serialized output
621+
if metadata_lookup_0.custom_oid {
622+
// we try to get the maximum oid we encountered here
623+
// to be sure that we don't accidentally give out a fake oid below that collides with
624+
// something
625+
let mut max_oid = bind_collector_0
626+
.metadata
627+
.iter()
628+
.flat_map(|t| {
629+
[
630+
t.oid().unwrap_or_default(),
631+
t.array_oid().unwrap_or_default(),
632+
]
633+
})
634+
.max()
635+
.unwrap_or_default();
636+
let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
637+
let mut metadata_lookup_1 = PgAsyncMetadataLookup {
638+
custom_oid: false,
639+
generated_oids: Some(HashMap::new()),
640+
oid_generator: move |_, _| {
641+
max_oid += 2;
642+
(max_oid, max_oid + 1)
643+
},
644+
};
645+
let collect_bind_result_1 =
646+
query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
647+
648+
assert_eq!(
649+
bind_collector_0.binds.len(),
650+
bind_collector_0.metadata.len()
651+
);
652+
let fake_oid_locations = std::iter::zip(
653+
bind_collector_0
654+
.binds
655+
.iter()
656+
.zip(&bind_collector_0.metadata),
657+
&bind_collector_1.binds,
658+
)
659+
.enumerate()
660+
.flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
661+
// custom oids might appear in the serialized bind arguments for arrays or composite (record) types
662+
// in both cases the relevant buffer is a custom type on its own
663+
// so we only need to check the cases that contain a fake OID on their own
664+
let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
665+
(
666+
bytes_0.as_deref().unwrap_or_default(),
667+
bytes_1.as_deref().unwrap_or_default(),
668+
)
669+
} else {
670+
// for all other cases, just return an empty
671+
// list to make the iteration below a no-op
672+
// and prevent the need of boxing
673+
(&[] as &[_], &[] as &[_])
674+
};
675+
let lookup_map = metadata_lookup_1
676+
.generated_oids
677+
.as_ref()
678+
.map(|map| {
679+
map.values()
680+
.flat_map(|(oid, array_oid)| [*oid, *array_oid])
681+
.collect::<HashSet<_>>()
682+
})
683+
.unwrap_or_default();
684+
std::iter::zip(
685+
bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
686+
bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
687+
)
688+
.enumerate()
689+
.filter_map(move |(byte_index, (l, r))| {
690+
// here we infer if some byte sequence is a fake oid
691+
// We use the following conditions for that:
692+
//
693+
// * The first byte sequence matches the constant FAKE_OID
694+
// * The second sequence does not match the constant FAKE_OID
695+
// * The second sequence is contained in the set of generated oids,
696+
// otherwise we get false positives around the boundary
697+
// of a to be replaced byte sequence
698+
let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
699+
(l == FAKE_OID.to_be_bytes()
700+
&& r != FAKE_OID.to_be_bytes()
701+
&& lookup_map.contains(&r_val))
702+
.then_some((bind_index, byte_index))
703+
})
704+
})
705+
// Avoid storing the bind collectors in the returned Future
706+
.collect::<Vec<_>>();
707+
BindData {
708+
collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
709+
fake_oid_locations,
710+
generated_oids: metadata_lookup_1.generated_oids,
711+
bind_collector: bind_collector_1,
712+
}
713+
} else {
714+
BindData {
715+
collect_bind_result: collect_bind_result_0,
716+
fake_oid_locations: Vec::new(),
717+
generated_oids: None,
718+
bind_collector: bind_collector_0,
719+
}
720+
}
721+
}
722+
713723
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
714724

715725
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector

0 commit comments

Comments
 (0)