Skip to content

Commit f11aed4

Browse files
committed
[ExecuTorch] Arm Ethos: Do not depend on torch.testing._internal
Pull Request resolved: #8839 This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this. ghstack-source-id: 269065441 Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/)
1 parent 7ce47fc commit f11aed4

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

backends/arm/test/passes/test_rescale_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from executorch.backends.arm.test import common, conftest
1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1515
from parameterized import parameterized
16-
from torch.testing._internal import optests
1716

17+
class OpCheckError(Exception):
18+
pass
1819

1920
def test_rescale_op():
2021
sample_inputs = [
@@ -64,7 +65,7 @@ def test_nonzero_zp_for_int32():
6465
),
6566
]
6667
for sample_input in sample_inputs:
67-
with pytest.raises(optests.generate_tests.OpCheckError):
68+
with pytest.raises(OpCheckError):
6869
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
6970

7071

@@ -87,7 +88,7 @@ def test_zp_outside_range():
8788
),
8889
]
8990
for sample_input in sample_inputs:
90-
with pytest.raises(optests.generate_tests.OpCheckError):
91+
with pytest.raises(OpCheckError):
9192
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
9293

9394

backends/arm/test/runner_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,33 @@
3434
from torch.fx.node import Node
3535

3636
from torch.overrides import TorchFunctionMode
37-
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
3837
from tosa import TosaGraph
3938

4039
logger = logging.getLogger(__name__)
4140
logger.setLevel(logging.CRITICAL)
4241

42+
# Copied from PyTorch.
43+
# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
44+
# To avoid a dependency on _internal stuff.
45+
_torch_to_numpy_dtype_dict = {
46+
torch.bool: np.bool_,
47+
torch.uint8: np.uint8,
48+
torch.uint16: np.uint16,
49+
torch.uint32: np.uint32,
50+
torch.uint64: np.uint64,
51+
torch.int8: np.int8,
52+
torch.int16: np.int16,
53+
torch.int32: np.int32,
54+
torch.int64: np.int64,
55+
torch.float16: np.float16,
56+
torch.float32: np.float32,
57+
torch.float64: np.float64,
58+
torch.bfloat16: np.float32,
59+
torch.complex32: np.complex64,
60+
torch.complex64: np.complex64,
61+
torch.complex128: np.complex128,
62+
}
63+
4364

4465
class QuantizationParams:
4566
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
@@ -335,7 +356,7 @@ def run_corstone(
335356
output_dtype = node.meta["val"].dtype
336357
tosa_ref_output = np.fromfile(
337358
os.path.join(intermediate_path, f"out-{i}.bin"),
338-
torch_to_numpy_dtype_dict[output_dtype],
359+
_torch_to_numpy_dtype_dict[output_dtype],
339360
)
340361

341362
output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
@@ -349,7 +370,7 @@ def prep_data_for_save(
349370
):
350371
if isinstance(data, torch.Tensor):
351372
data_np = np.array(data.detach(), order="C").astype(
352-
torch_to_numpy_dtype_dict[data.dtype]
373+
_torch_to_numpy_dtype_dict[data.dtype]
353374
)
354375
else:
355376
data_np = np.array(data)

0 commit comments

Comments
 (0)