Skip to content

Commit 9344de1

Browse files
committed
bf16 support
1 parent ca59597 commit 9344de1

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

issue_3458.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from __future__ import annotations
2+
3+
import os
4+
5+
import torch
6+
import torch_tensorrt
7+
8+
os.environ["CI_BUILD"] = "1"
9+
10+
dtype = torch.bfloat16
11+
device = torch.device("cuda", 0)
12+
13+
14+
class MyModule(torch.nn.Module):
15+
def __init__(self) -> None:
16+
super().__init__()
17+
18+
def forward(self, x: torch.Tensor) -> torch.Tensor:
19+
return x * 0.5
20+
21+
22+
with torch.inference_mode():
23+
model = MyModule().eval().to(device, dtype)
24+
inputs = (torch.randn(1, 3, 224, 224, dtype=dtype, device=device),)
25+
exported_program = torch.export.export(model, inputs)
26+
27+
trt_model = torch_tensorrt.dynamo.compile(
28+
exported_program,
29+
inputs,
30+
device=device,
31+
enabled_precisions={dtype},
32+
debug=True,
33+
min_block_size=1,
34+
)
35+
36+
print(trt_model(inputs[0]))

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
get_trt_tensor,
1616
has_dynamic_shape,
1717
set_layer_name,
18+
to_torch,
1819
)
1920
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
2021

21-
import tensorrt as trt
22-
2322

2423
def get_python_op_from_trt_elementwise_op(
2524
trt_op: TRTElementWiseOp,
@@ -125,10 +124,9 @@ def convert_binary_elementwise(
125124
# dtype but we don't have a way to detect whether it makes sense for the
126125
# scalar to be float or half. Hence we go with the lhs dtype.
127126
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
128-
rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
127+
rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
129128
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
130-
lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
131-
129+
lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
132130
lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
133131
rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
134132

tests/py/dynamo/conversion/test_binary_ops_aten.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import torch
44
import torch.nn as nn
5+
from .harness import DispatchTestCase
56
from parameterized import parameterized
67
from torch.testing._internal.common_utils import run_tests
78
from torch_tensorrt import Input
89

9-
from .harness import DispatchTestCase
10-
1110
NEED_TEST_BOTH_CONSTANTS_CASE = True
1211

1312
elementwise_ops = [
@@ -228,6 +227,28 @@ def forward(self, x, y):
228227
]
229228
self.run_test_with_dynamic_shape(Op(), input_specs)
230229

230+
@parameterized.expand(
231+
[
232+
(f"bf16_{op[0].__name__}_one_constant", op[0])
233+
for op in elementwise_ops
234+
if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
235+
]
236+
)
237+
def test_elementwise_ops_bf16(self, _, orig_op):
238+
class TestModule(nn.Module):
239+
def __init__(self, orig_op):
240+
super().__init__()
241+
self.constant = torch.randn(1)
242+
self.orig_op = orig_op
243+
244+
def forward(self, x):
245+
x = self.orig_op(x, self.constant)
246+
return self.orig_op(x, -2)
247+
248+
m = TestModule(orig_op)
249+
inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
250+
self.run_test(m, inputs)
251+
231252

232253
if __name__ == "__main__":
233254
run_tests()

0 commit comments

Comments
 (0)