Skip to content

Commit f9bb47c

Browse files
committed
Adjust autodiff activities for abi transformations (small tuples and structs)
1 parent d00435f commit f9bb47c

File tree

1 file changed

+35
-0
lines changed
  • compiler/rustc_monomorphize/src/partitioning

1 file changed

+35
-0
lines changed

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use rustc_abi::HasDataLayout;
12
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
23
use rustc_hir::def_id::LOCAL_CRATE;
34
use rustc_middle::bug;
@@ -16,6 +17,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
1617
// We don't actually pass the types back into the type system.
1718
// All we do is decide how to handle the arguments.
1819
let sig = fn_ty.fn_sig(tcx).skip_binder();
20+
let pointer_size = tcx.data_layout().pointer_size;
1921

2022
let mut new_activities = vec![];
2123
let mut new_positions = vec![];
@@ -70,6 +72,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
7072
continue;
7173
}
7274
}
75+
76+
let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
77+
78+
let layout = match tcx.layout_of(pci) {
79+
Ok(layout) => layout.layout,
80+
Err(_) => {
81+
bug!("failed to compute layout for type {:?}", ty);
82+
}
83+
};
84+
85+
let is_product = |t: Ty<'tcx>| matches!(t.kind(), ty::Tuple(_) | ty::Adt(_, _));
86+
87+
if layout.size() <= pointer_size * 2 && is_product(*ty) {
88+
let n_scalars = count_scalar_fields(tcx, *ty);
89+
for _ in 0..n_scalars.saturating_sub(1) {
90+
new_activities.push(da[i].clone());
91+
new_positions.push(i + 1);
92+
}
93+
}
7394
}
7495
// now add the extra activities coming from slices
7596
// Reverse order to not invalidate the indices
@@ -80,6 +101,20 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
80101
}
81102
}
82103

104+
fn count_scalar_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize {
105+
match ty.kind() {
106+
ty::Float(_) | ty::Int(_) | ty::Uint(_) => 1,
107+
ty::Adt(def, substs) if def.is_struct() => def
108+
.non_enum_variant()
109+
.fields
110+
.iter()
111+
.map(|f| count_scalar_fields(tcx, f.ty(tcx, substs)))
112+
.sum(),
113+
ty::Tuple(substs) => substs.iter().map(|t| count_scalar_fields(tcx, t)).sum(),
114+
_ => 0,
115+
}
116+
}
117+
83118
pub(crate) fn find_autodiff_source_functions<'tcx>(
84119
tcx: TyCtxt<'tcx>,
85120
usage_map: &UsageMap<'tcx>,

0 commit comments

Comments
 (0)