Skip to content

Commit edcc7e9

Browse files
updated the code to avoid deepcopy()
1 parent e7563f6 commit edcc7e9

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

intermediate_source/mario_rl_tutorial.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -424,7 +424,23 @@ def __init__(self, input_dim, output_dim):
424 424
if w != 84:
425 425
raise ValueError(f"Expecting input width: 84, got: {w}")
426 426

427-
self.online = nn.Sequential(
427+
self.online = self._build_cnn(input_dim, output_dim)
428+
self.target = self._build_cnn(input_dim, output_dim)
429+
430+
# Q_target parameters are frozen.
431+
for p in self.target.parameters():
432+
p.requires_grad = False
433+
434+
def forward(self, input, model):
435+
if model == "online":
436+
return self.online(input)
437+
elif model == "target":
438+
return self.target(input)
439+
440+
def _build_cnn(self, input_dim, output_dim):
441+
c, _, _ = input_dim
442+
443+
cnn = nn.Sequential(
428 444
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
429 445
nn.ReLU(),
430 446
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
@@ -437,17 +453,7 @@ def __init__(self, input_dim, output_dim):
437 453
nn.Linear(512, output_dim),
438 454
)
439 455

440-
self.target = copy.deepcopy(self.online)
441-
442-
# Q_target parameters are frozen.
443-
for p in self.target.parameters():
444-
p.requires_grad = False
445-
446-
def forward(self, input, model):
447-
if model == "online":
448-
return self.online(input)
449-
elif model == "target":
450-
return self.target(input)
456+
return cnn
451 457

452 458

453 459
######################################################################

0 commit comments

Comments
 (0)