1
+ use rustc_abi:: HasDataLayout ;
1
2
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffItem , DiffActivity } ;
2
3
use rustc_hir:: def_id:: LOCAL_CRATE ;
3
4
use rustc_middle:: bug;
@@ -16,6 +17,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
16
17
// We don't actually pass the types back into the type system.
17
18
// All we do is decide how to handle the arguments.
18
19
let sig = fn_ty. fn_sig ( tcx) . skip_binder ( ) ;
20
+ let pointer_size = tcx. data_layout ( ) . pointer_size ;
19
21
20
22
let mut new_activities = vec ! [ ] ;
21
23
let mut new_positions = vec ! [ ] ;
@@ -70,6 +72,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
70
72
continue ;
71
73
}
72
74
}
75
+
76
+ let pci = PseudoCanonicalInput { typing_env : TypingEnv :: fully_monomorphized ( ) , value : * ty } ;
77
+
78
+ let layout = match tcx. layout_of ( pci) {
79
+ Ok ( layout) => layout. layout ,
80
+ Err ( _) => {
81
+ bug ! ( "failed to compute layout for type {:?}" , ty) ;
82
+ }
83
+ } ;
84
+
85
+ let is_product = |t : Ty < ' tcx > | matches ! ( t. kind( ) , ty:: Tuple ( _) | ty:: Adt ( _, _) ) ;
86
+
87
+ if layout. size ( ) <= pointer_size * 2 && is_product ( * ty) {
88
+ let n_scalars = count_scalar_fields ( tcx, * ty) ;
89
+ for _ in 0 ..n_scalars. saturating_sub ( 1 ) {
90
+ new_activities. push ( da[ i] . clone ( ) ) ;
91
+ new_positions. push ( i + 1 ) ;
92
+ }
93
+ }
73
94
}
74
95
// now add the extra activities coming from slices
75
96
// Reverse order to not invalidate the indices
@@ -80,6 +101,20 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
80
101
}
81
102
}
82
103
104
+ fn count_scalar_fields < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> usize {
105
+ match ty. kind ( ) {
106
+ ty:: Float ( _) | ty:: Int ( _) | ty:: Uint ( _) => 1 ,
107
+ ty:: Adt ( def, substs) if def. is_struct ( ) => def
108
+ . non_enum_variant ( )
109
+ . fields
110
+ . iter ( )
111
+ . map ( |f| count_scalar_fields ( tcx, f. ty ( tcx, substs) ) )
112
+ . sum ( ) ,
113
+ ty:: Tuple ( substs) => substs. iter ( ) . map ( |t| count_scalar_fields ( tcx, t) ) . sum ( ) ,
114
+ _ => 0 ,
115
+ }
116
+ }
117
+
83
118
pub ( crate ) fn find_autodiff_source_functions < ' tcx > (
84
119
tcx : TyCtxt < ' tcx > ,
85
120
usage_map : & UsageMap < ' tcx > ,
0 commit comments