@@ -424,7 +424,23 @@ def __init__(self, input_dim, output_dim):
         if w != 84:
             raise ValueError(f"Expecting input width: 84, got: {w}")
 
-        self.online = nn.Sequential(
+        self.online = self._build_cnn(input_dim, output_dim)
+        self.target = self._build_cnn(input_dim, output_dim)
+
+        # Q_target parameters are frozen.
+        for p in self.target.parameters():
+            p.requires_grad = False
+
+    def forward(self, input, model):
+        if model == "online":
+            return self.online(input)
+        elif model == "target":
+            return self.target(input)
+
+    def _build_cnn(self, input_dim, output_dim):
+        c, _, _ = input_dim
+
+        cnn = nn.Sequential(
             nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
             nn.ReLU(),
             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
@@ -437,17 +453,7 @@ def __init__(self, input_dim, output_dim):
             nn.Linear(512, output_dim),
         )
 
-        self.target = copy.deepcopy(self.online)
-
-        # Q_target parameters are frozen.
-        for p in self.target.parameters():
-            p.requires_grad = False
-
-    def forward(self, input, model):
-        if model == "online":
-            return self.online(input)
-        elif model == "target":
-            return self.target(input)
+        return cnn
 
 
######################################################################
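
With this change, the online and target networks are both built through the shared _build_cnn helper instead of deep-copying the online network, so the target network starts from its own random initialization and only takes on the online weights when its state dict is synced. Below is a minimal usage sketch, assuming the MarioNet class modified above and the tutorial's stacked 4x84x84 observations; the action count of 2 is illustrative:

    import torch

    # Build the network; output_dim is the number of discrete actions.
    net = MarioNet(input_dim=(4, 84, 84), output_dim=2)

    state = torch.rand(1, 4, 84, 84)        # one batch of stacked frames
    q_online = net(state, model="online")   # trainable online Q-estimates
    q_target = net(state, model="target")   # frozen target Q-estimates

    # Periodically copy the online weights into the frozen target network.
    net.target.load_state_dict(net.online.state_dict())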