feat: Add sample torch.compile backend for tensorrt aten path #1751

Merged: 14 commits, Apr 13, 2023
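This PR wires a sample `torch_compile` IR into the `torch_tensorrt` compile dispatch. A minimal usage sketch, based on the tests added below (model choice and shapes are illustrative):

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")

# Compile through the new Dynamo torch_compile path
trt_mod = torchtrt.compile(
    model,
    ir="torch_compile",
    inputs=[torchtrt.Input((1, 3, 224, 224), dtype=torch.float)],
    enabled_precisions={torch.float},
)
trt_mod(torch.randn((1, 3, 224, 224)).to("cuda"))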
19 changes: 19 additions & 0 deletions .circleci/config.yml
@@ -869,6 +869,24 @@ commands:
      - store_artifacts:
          path: /tmp/testlogs

  test-dynamo-torch_compile:
    description: "Test the Dynamo torch_compile path"
    steps:
      - run:
          name: Run Dynamo torch_compile E2E tests
          command: |
            cd py/torch_tensorrt/dynamo/
            pushd test/
            pip3 install timm
            pip3 install transformers
            pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile
            popd

      - store_test_results:
          path: /tmp/artifacts
      - store_artifacts:
          path: /tmp/testlogs

  test-dynamo-fx_ts:
    description: "Test the dynamo backend"
    steps:
@@ -1117,6 +1135,7 @@ jobs:
          command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
      # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
      - dump-test-env
      - test-dynamo-torch_compile
      - test-dynamo-fx_ts

  test-py-dynamo-x86_64-linux-no-aten:
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_Input.py
@@ -367,10 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
            )

        if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.randn(self.shape).to(
+            return torch.rand(self.shape).to(
                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
            )
        else:
-            return torch.randn(self.shape[optimization_profile_field]).to(
+            return torch.rand(self.shape[optimization_profile_field]).to(
                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
            )
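For context on the change above: `torch.randn` samples from a standard normal distribution, so example tensors can contain arbitrary negative values, while `torch.rand` samples uniformly from [0, 1), keeping example inputs in a bounded, non-negative range (the apparent motivation here). A quick illustration:

import torch

torch.manual_seed(0)
print(torch.randn(3))  # standard normal: unbounded, values may be negative
print(torch.rand(3))   # uniform on [0, 1): bounded and non-negative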
1 change: 1 addition & 0 deletions py/torch_tensorrt/__init__.py
@@ -94,6 +94,7 @@ def _find_lib(name, paths):

from torch_tensorrt import fx
from torch_tensorrt import dynamo
from torch_tensorrt.dynamo import torch_compile


def _register_with_torch():
8 changes: 8 additions & 0 deletions py/torch_tensorrt/_compile.py
@@ -16,6 +16,7 @@ class _IRType(Enum):
    ts = 0
    fx = 1
    fx_ts_compat = 2
    torch_compile = 3


class _ModuleType(Enum):
@@ -46,6 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:

    ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
    ir_targets_fx = ir == "fx"
    ir_targets_torch_compile = ir == "torch_compile"
    ir_targets_fx_ts_compat = ir == "fx_ts_compat"

    if module_is_tsable and ir_targets_torchscript:
@@ -54,6 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
        return _IRType.fx
    elif module_is_fxable and ir_targets_fx_ts_compat:
        return _IRType.fx_ts_compat
    elif module_is_fxable and ir_targets_torch_compile:
        return _IRType.torch_compile
    else:
        if ir == "default":
            # Options are listed in order of preference
@@ -152,6 +156,10 @@ def compile(
            dynamic_batch=False,
            **kwargs,
        )
    elif target_ir == _IRType.torch_compile:
        return torch_tensorrt.dynamo.torch_compile(
            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
        )
    elif target_ir == _IRType.fx_ts_compat:
        return torch_tensorrt.dynamo.fx_ts_compat.compile(
            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
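Given the dispatch added above, the following two calls are expected to be equivalent (a sketch; the keyword set mirrors this diff, and the model is illustrative):

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torchtrt.Input((1, 3, 224, 224), dtype=torch.float)]

# ir="torch_compile" resolves to _IRType.torch_compile and forwards to the Dynamo entry point...
trt_mod_a = torchtrt.compile(model, ir="torch_compile", inputs=inputs, enabled_precisions={torch.float})

# ...which should match calling that entry point directly
trt_mod_b = torchtrt.dynamo.torch_compile(model, inputs=inputs, enabled_precisions={torch.float})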
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -1 +1,2 @@
from torch_tensorrt.dynamo import fx_ts_compat
from .torch_compile import compile as torch_compile
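With this re-export, plus the top-level import added in py/torch_tensorrt/__init__.py above, the same callable should be reachable under both names (a sketch):

import torch_tensorrt
from torch_tensorrt.dynamo import torch_compile

# Both names are expected to resolve to compile() from the torch_compile submodule
assert torch_tensorrt.torch_compile is torch_compile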
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/test/conftest.py
@@ -0,0 +1,18 @@
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--ir",
        metavar="Internal Representation",
        nargs=1,
        type=str,
        required=True,
        help="IR to compile with",
        choices=["torch_compile", "fx_ts_compat"],
    )


@pytest.fixture
def ir(request):
    return request.config.getoption("--ir")[0]
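With this conftest, any test that declares an `ir` parameter receives the value passed on the command line (the CI job above passes `--ir torch_compile`). A minimal sketch of a test consuming the fixture (the test name is hypothetical):

import pytest


@pytest.mark.unit
def test_ir_fixture(ir):
    # `ir` is supplied by the fixture above; the parser restricts the choices
    assert ir in ("torch_compile", "fx_ts_compat")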
144 changes: 144 additions & 0 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -0,0 +1,144 @@
import torch
import timm
import pytest

import torch_tensorrt as torchtrt
import torchvision.models as models

from transformers import BertModel

from utils import COSINE_THRESHOLD, cosine_similarity


@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
    cos_sim = cosine_similarity(model(input), trt_mod(input))
    assert cos_sim > COSINE_THRESHOLD, (
        f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"
    )


@pytest.mark.unit
def test_mobilenet_v2(ir):
    model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
    cos_sim = cosine_similarity(model(input), trt_mod(input))
    assert cos_sim > COSINE_THRESHOLD, (
        f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"
    )


@pytest.mark.unit
def test_efficientnet_b0(ir):
    model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
    cos_sim = cosine_similarity(model(input), trt_mod(input))
    assert cos_sim > COSINE_THRESHOLD, (
        f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"
    )


@pytest.mark.unit
def test_bert_base_uncased(ir):
    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
    input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input.shape,
                dtype=input.dtype,
                format=torch.contiguous_format,
            ),
            torchtrt.Input(
                input.shape,
                dtype=input.dtype,
                format=torch.contiguous_format,
            ),
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "truncate_long_and_double": True,
        "debug": True,
        "ir": ir,
    }
    trt_mod = torchtrt.compile(model, **compile_spec)

    model_outputs = model(input, input2)
    trt_model_outputs = trt_mod(input, input2)
    for key in model_outputs.keys():
        out, trt_out = model_outputs[key], trt_model_outputs[key]
        cos_sim = cosine_similarity(out, trt_out)
        assert cos_sim > COSINE_THRESHOLD, (
            f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"
        )


@pytest.mark.unit
def test_resnet18_half(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda").half()
    input = torch.randn((1, 3, 224, 224)).to("cuda").half()

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input.shape, dtype=torch.half, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.half},
        "ir": ir,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
    cos_sim = cosine_similarity(model(input), trt_mod(input))
    assert cos_sim > COSINE_THRESHOLD, (
        f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"
    )
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/test/utils.py
@@ -0,0 +1,54 @@
import torch

COSINE_THRESHOLD = 0.99


def cosine_similarity(gt_tensor, pred_tensor):
    gt_tensor = gt_tensor.flatten().to(torch.float32)
    pred_tensor = pred_tensor.flatten().to(torch.float32)
    if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
        if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
            return 1.0
    res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
    res = res.cpu().detach().item()

    return res


def same_output_format(trt_output, torch_output):
    # For each encountered collection type, ensure the torch and trt outputs agree
    # on type and size, checking recursively through all member elements.
    if isinstance(trt_output, tuple):
        return (
            isinstance(torch_output, tuple)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, list):
        return (
            isinstance(torch_output, list)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, dict):
        return (
            isinstance(torch_output, dict)
            and (len(trt_output) == len(torch_output))
            and (trt_output.keys() == torch_output.keys())
            and all(
                same_output_format(trt_output[key], torch_output[key])
                for key in trt_output.keys()
            )
        )
    elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )
    else:
        return type(trt_output) is type(torch_output)
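A quick sanity check of these helpers (illustrative, not part of the PR):

import torch

from utils import COSINE_THRESHOLD, cosine_similarity, same_output_format

a = torch.randn(4, 8)
b = a.clone()

assert cosine_similarity(a, b) > COSINE_THRESHOLD  # identical tensors score ~1.0
assert same_output_format((a,), (b,))              # same container type and length
assert not same_output_format([a], (b,))           # list vs. tuple fails the check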