@@ -64,7 +64,7 @@ impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
64
64
self . n_steps . push ( Some ( value. n_steps ) ) ;
65
65
}
66
66
67
- fn finalize ( mut self ) -> StructArray {
67
+ fn finalize ( mut self ) -> Option < StructArray > {
68
68
let fields = vec ! [
69
69
Field :: new( "step_size_bar" , DataType :: Float64 , false ) ,
70
70
Field :: new( "mean_tree_accept" , DataType :: Float64 , false ) ,
@@ -77,7 +77,7 @@ impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
77
77
self . n_steps. as_box( ) ,
78
78
] ;
79
79
80
- StructArray :: new ( DataType :: Struct ( fields) , arrays, None )
80
+ Some ( StructArray :: new ( DataType :: Struct ( fields) , arrays, None ) )
81
81
}
82
82
}
83
83
@@ -177,19 +177,11 @@ impl<F: CpuLogpFunc, M: MassMatrix> AdaptStrategy for DualAverageStrategy<F, M>
177
177
}
178
178
179
179
/// Settings for mass matrix adaptation
180
- #[ derive( Clone , Copy , Debug ) ]
180
+ #[ derive( Clone , Copy , Debug , Default ) ]
181
181
pub struct DiagAdaptExpSettings {
182
182
pub store_mass_matrix : bool ,
183
183
}
184
184
185
- impl Default for DiagAdaptExpSettings {
186
- fn default ( ) -> Self {
187
- Self {
188
- store_mass_matrix : false ,
189
- }
190
- }
191
- }
192
-
193
185
pub ( crate ) struct ExpWindowDiagAdapt < F > {
194
186
dim : usize ,
195
187
exp_variance_draw : RunningVariance ,
@@ -284,7 +276,7 @@ impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
284
276
}
285
277
}
286
278
287
- fn finalize ( self ) -> StructArray {
279
+ fn finalize ( self ) -> Option < StructArray > {
288
280
if let Some ( mut store) = self . mass_matrix_inv {
289
281
let fields = vec ! [ Field :: new(
290
282
"mass_matrix_inv" ,
@@ -294,9 +286,9 @@ impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
294
286
295
287
let arrays = vec ! [ store. as_box( ) ] ;
296
288
297
- StructArray :: new ( DataType :: Struct ( fields) , arrays, None )
289
+ Some ( StructArray :: new ( DataType :: Struct ( fields) , arrays, None ) )
298
290
} else {
299
- StructArray :: new ( DataType :: Struct ( vec ! [ ] ) , vec ! [ ] , None )
291
+ None
300
292
}
301
293
}
302
294
}
@@ -570,17 +562,24 @@ impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowBuilder<CombinedStats<D1,
570
562
self . stats2 . append_value ( & value. stats2 ) ;
571
563
}
572
564
573
- fn finalize ( self ) -> StructArray {
574
- let mut data1 = self . stats1 . finalize ( ) . into_data ( ) ;
575
- let data2 = self . stats2 . finalize ( ) . into_data ( ) ;
565
+ fn finalize ( self ) -> Option < StructArray > {
566
+ match ( self . stats1 . finalize ( ) , self . stats2 . finalize ( ) ) {
567
+ ( None , None ) => None ,
568
+ ( Some ( stats1) , None ) => Some ( stats1) ,
569
+ ( None , Some ( stats2) ) => Some ( stats2) ,
570
+ ( Some ( stats1) , Some ( stats2) ) => {
571
+ let mut data1 = stats1. into_data ( ) ;
572
+ let data2 = stats2. into_data ( ) ;
576
573
577
- assert ! ( data1. 2 . is_none( ) ) ;
578
- assert ! ( data2. 2 . is_none( ) ) ;
574
+ assert ! ( data1. 2 . is_none( ) ) ;
575
+ assert ! ( data2. 2 . is_none( ) ) ;
579
576
580
- data1. 0 . extend ( data2. 0 ) ;
581
- data1. 1 . extend ( data2. 1 ) ;
577
+ data1. 0 . extend ( data2. 0 ) ;
578
+ data1. 1 . extend ( data2. 1 ) ;
582
579
583
- StructArray :: new ( DataType :: Struct ( data1. 0 ) , data1. 1 , None )
580
+ Some ( StructArray :: new ( DataType :: Struct ( data1. 0 ) , data1. 1 , None ) )
581
+ }
582
+ }
584
583
}
585
584
}
586
585
0 commit comments