Skip to content

Commit b2284b5

Browse files
Develop fix visual torchh export (#4494) (#4497)
* Fixing exporting of ONNX for visual when using threading * docstring was wrong
1 parent 736ac08 commit b2284b5

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

ml-agents/mlagents/trainers/torch/model_serialization.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import threading
23
from mlagents.torch_utils import torch
34

45
from mlagents_envs.logging_util import get_logger
@@ -8,6 +9,32 @@
89
logger = get_logger(__name__)
910

1011

12+
class exporting_to_onnx:
13+
"""
14+
Set this context by calling
15+
```
16+
with exporting_to_onnx():
17+
```
18+
Within this context, the variable exporting_to_onnx.is_exporting() will be true.
19+
This implementation is thread safe.
20+
"""
21+
22+
_local_data = threading.local()
23+
_local_data._is_exporting = False
24+
25+
def __enter__(self):
26+
self._local_data._is_exporting = True
27+
28+
def __exit__(self, *args):
29+
self._local_data._is_exporting = False
30+
31+
@staticmethod
32+
def is_exporting():
33+
if not hasattr(exporting_to_onnx._local_data, "_is_exporting"):
34+
return False
35+
return exporting_to_onnx._local_data._is_exporting
36+
37+
1138
class ModelSerializer:
1239
def __init__(self, policy):
1340
# ONNX only support input in NCHW (channel first) format.
@@ -61,13 +88,14 @@ def export_policy_model(self, output_filepath: str) -> None:
6188
onnx_output_path = f"{output_filepath}.onnx"
6289
logger.info(f"Converting to {onnx_output_path}")
6390

64-
torch.onnx.export(
65-
self.policy.actor_critic,
66-
self.dummy_input,
67-
onnx_output_path,
68-
opset_version=SerializationSettings.onnx_opset,
69-
input_names=self.input_names,
70-
output_names=self.output_names,
71-
dynamic_axes=self.dynamic_axes,
72-
)
91+
with exporting_to_onnx():
92+
torch.onnx.export(
93+
self.policy.actor_critic,
94+
self.dummy_input,
95+
onnx_output_path,
96+
opset_version=SerializationSettings.onnx_opset,
97+
input_names=self.input_names,
98+
output_names=self.output_names,
99+
dynamic_axes=self.dynamic_axes,
100+
)
73101
logger.info(f"Exported {onnx_output_path}")

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mlagents.trainers.torch.utils import ModelUtils
1414
from mlagents.trainers.torch.decoders import ValueHeads
1515
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
16+
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
1617

1718
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
1819
EncoderFunction = Callable[
@@ -84,7 +85,7 @@ def forward(
8485

8586
for idx, processor in enumerate(self.visual_processors):
8687
vis_input = vis_inputs[idx]
87-
if not torch.onnx.is_in_onnx_export():
88+
if not exporting_to_onnx.is_exporting():
8889
vis_input = vis_input.permute([0, 3, 1, 2])
8990
processed_vis = processor(vis_input)
9091
encodes.append(processed_vis)

0 commit comments

Comments
 (0)