@@ -256,6 +256,7 @@ def create_trainer(args):
256
256
click_probability_key = "probability" ,
257
257
train = False ,
258
258
label_names = args .labels ,
259
+ max_interactions = args .max_val_interactions ,
259
260
),
260
261
inferer = SimpleInferer (),
261
262
postprocessing = post_transform ,
@@ -307,6 +308,7 @@ def create_trainer(args):
307
308
click_probability_key = "probability" ,
308
309
train = True ,
309
310
label_names = args .labels ,
311
+ max_interactions = args .max_train_interactions ,
310
312
),
311
313
optimizer = optimizer ,
312
314
loss_function = loss_function ,
@@ -393,8 +395,8 @@ def main():
393
395
394
396
parser .add_argument ("-f" , "--val_freq" , type = int , default = 1 )
395
397
parser .add_argument ("-lr" , "--learning_rate" , type = float , default = 0.0001 )
396
- parser .add_argument ("-it" , "--max_train_interactions" , type = int , default = 15 )
397
- parser .add_argument ("-iv" , "--max_val_interactions" , type = int , default = 5 )
398
+ parser .add_argument ("-it" , "--max_train_interactions" , type = int , default = 1 )
399
+ parser .add_argument ("-iv" , "--max_val_interactions" , type = int , default = 1 )
398
400
399
401
parser .add_argument ("-dpt" , "--deepgrow_probability_train" , type = float , default = 0.4 )
400
402
parser .add_argument ("-dpv" , "--deepgrow_probability_val" , type = float , default = 1.0 )
0 commit comments