Skip to content

Commit aec2861

Browse files
committed
Check signature before cloning callee body.
1 parent 930faff commit aec2861

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

compiler/rustc_mir_transform/src/inline.rs

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl<'tcx> Inliner<'tcx> {
183183

184184
self.check_mir_is_available(caller_body, &callsite.callee)?;
185185
let callee_body = try_instance_mir(self.tcx, callsite.callee.def)?;
186-
self.check_mir_body(callsite, callee_body, callee_attrs)?;
186+
self.check_mir_body(caller_body, callsite, callee_body, callee_attrs)?;
187187

188188
if !self.tcx.consider_optimizing(|| {
189189
format!("Inline {:?} into {:?}", callsite.callee, caller_body.source)
@@ -199,8 +199,6 @@ impl<'tcx> Inliner<'tcx> {
199199
return Err("failed to normalize callee body");
200200
};
201201

202-
self.check_subst_body(caller_body, callsite, &callee_body)?;
203-
204202
let old_blocks = caller_body.basic_blocks.next_index();
205203
self.inline_call(caller_body, &callsite, callee_body);
206204
let new_blocks = old_blocks..caller_body.basic_blocks.next_index();
@@ -370,6 +368,7 @@ impl<'tcx> Inliner<'tcx> {
370368
#[instrument(level = "debug", skip(self, callee_body))]
371369
fn check_mir_body(
372370
&self,
371+
caller_body: &Body<'tcx>,
373372
callsite: &CallSite<'tcx>,
374373
callee_body: &Body<'tcx>,
375374
callee_attrs: &CodegenFnAttrs,
@@ -440,32 +439,21 @@ impl<'tcx> Inliner<'tcx> {
440439
// Abort if type validation found anything fishy.
441440
checker.validation?;
442441

443-
let cost = checker.cost;
444-
if let InlineAttr::Always = callee_attrs.inline {
445-
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
446-
Ok(())
447-
} else if cost <= threshold {
448-
debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold);
449-
Ok(())
450-
} else {
451-
debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold);
452-
Err("cost above threshold")
453-
}
454-
}
442+
let substitute = |ty| {
443+
let ty = ty::EarlyBinder::bind(ty);
444+
callsite
445+
.callee
446+
.try_subst_mir_and_normalize_erasing_regions(self.tcx, self.param_env, ty)
447+
.map_err(|_| "failed to normalize callee body")
448+
};
455449

456-
/// Check call signature compatibility.
457-
/// Normally, this shouldn't be required, but trait normalization failure can create a
458-
/// validation ICE.
459-
fn check_subst_body(
460-
&self,
461-
caller_body: &Body<'tcx>,
462-
callsite: &CallSite<'tcx>,
463-
callee_body: &Body<'tcx>,
464-
) -> Result<(), &'static str> {
450+
// Check call signature compatibility.
451+
// Normally, this shouldn't be required, but trait normalization failure can create a
452+
// validation ICE.
465453
let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
466454
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
467455
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
468-
let output_type = callee_body.return_ty();
456+
let output_type = substitute(callee_body.return_ty())?;
469457
if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) {
470458
trace!(?output_type, ?destination_ty);
471459
return Err("failed to normalize return type");
@@ -485,15 +473,15 @@ impl<'tcx> Inliner<'tcx> {
485473
for (arg_ty, input) in
486474
arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
487475
{
488-
let input_type = callee_body.local_decls[input].ty;
476+
let input_type = substitute(callee_body.local_decls[input].ty)?;
489477
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
490478
trace!(?arg_ty, ?input_type);
491479
return Err("failed to normalize tuple argument type");
492480
}
493481
}
494482
} else {
495483
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
496-
let input_type = callee_body.local_decls[input].ty;
484+
let input_type = substitute(callee_body.local_decls[input].ty)?;
497485
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
498486
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
499487
trace!(?arg_ty, ?input_type);
@@ -502,7 +490,17 @@ impl<'tcx> Inliner<'tcx> {
502490
}
503491
}
504492

505-
Ok(())
493+
let cost = checker.cost;
494+
if let InlineAttr::Always = callee_attrs.inline {
495+
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
496+
Ok(())
497+
} else if cost <= threshold {
498+
debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold);
499+
Ok(())
500+
} else {
501+
debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold);
502+
Err("cost above threshold")
503+
}
506504
}
507505

508506
fn inline_call(

0 commit comments

Comments
 (0)