fix: Repair usage of torch_executed_ops #2562

Merged: 1 commit, merged Jan 3, 2024
15 changes: 11 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -2,10 +2,11 @@

import collections.abc
import logging
-from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
from torch.export import ExportedProgram
+from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
EngineCapability,
@@ -49,6 +50,9 @@
convert_module,
repair_long_or_double_inputs,
)
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.utils import (
get_torch_inputs,
@@ -85,7 +89,7 @@ def compile(
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[List[str]] = None,
+    torch_executed_ops: Optional[Collection[Target]] = None,
torch_executed_modules: Optional[List[str]] = None,
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
@@ -143,7 +147,7 @@ def compile(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
-        torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True
torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
max_aux_stream (Optional[int]): Maximum streams in the engine
@@ -212,7 +216,7 @@ def compile(
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops
if torch_executed_ops is not None
-        else [],
+        else set(),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
@@ -256,6 +260,9 @@ def compile_module(
"""
dryrun_tracker = DryRunTracker()

+    # Set torch-executed ops
+    CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
+
# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
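For reference, a minimal sketch of calling the updated entry point (the toy Net module, input shapes, and default precisions here are illustrative assumptions, not part of the PR): torch_executed_ops can now be passed as a collection of aten Targets rather than a list of strings.

import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(torch.abs(x)) + 1

model = Net().eval().cuda()
inputs = [torch.randn(2, 4).cuda()]

# Ops listed here are kept in PyTorch; aten OpOverload objects (and still
# their qualified-name strings) are accepted after this change.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,
    torch_executed_ops={torch.ops.aten.abs.default},
)
print(trt_model(*inputs))

With min_block_size=1, everything except aten.abs should land in a TRT engine while abs stays in Torch.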
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
-from typing import Optional, Set, Union
+from typing import Collection, Optional, Union

import torch
from tensorrt import EngineCapability
+from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._defaults import (
DEBUG,
@@ -41,7 +42,7 @@ class CompilationSettings:
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
-        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops (Collection[Target]): Collection of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -75,7 +76,7 @@ class CompilationSettings:
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
-    torch_executed_ops: Set[str] = field(default_factory=set)
+    torch_executed_ops: Collection[Target] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
max_aux_streams: Optional[int] = MAX_AUX_STREAMS
version_compatible: bool = VERSION_COMPATIBLE
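A quick illustration of the settings change (a sketch that constructs CompilationSettings directly; in normal use it is built internally from the compile() arguments):

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

# torch_executed_ops is now typed as Collection[Target]; fx Targets are either
# callables (e.g. aten OpOverloads) or strings, so both forms can be stored.
settings = CompilationSettings(
    min_block_size=1,
    torch_executed_ops={torch.ops.aten.abs.default, "torch.ops.aten.relu.default"},
)
print(settings.torch_executed_ops)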
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -6,6 +6,7 @@
from typing import (
Any,
Callable,
+    Collection,
Dict,
List,
Optional,
@@ -212,8 +213,16 @@ def __init__(
CallingConvention.CTX for _ in range(len(self.registries))
]

+        self.disallowed_targets: Collection[Target] = set()
+
self.validate_invariants()

+    def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
+        self.disallowed_targets = torch_executed_ops
+
+    def get_disallowed_targets(self) -> Collection[Target]:
+        return self.disallowed_targets
+
def validate_invariants(self) -> None:
"""Validates the invariants required of the dictionaries in the registries

@@ -253,6 +262,14 @@ def __getitem_without_validation__(

self.validate_invariants()

+        if (
+            key in self.disallowed_targets
+            or self.qualified_name_or_str(key) in self.disallowed_targets
+        ):
+            raise KeyError(
+                f"A converter exists for {key}, but it was " "explicitly disallowed"
+            )
+
# Iterate over all registries and return the first converter found
for registry, calling_convention in zip(
self.registries, self.registry_calling_conventions
@@ -288,6 +305,14 @@ def __getitem__(
self.validate_invariants()
key = node.target

+        if (
+            key in self.disallowed_targets
+            or self.qualified_name_or_str(key) in self.disallowed_targets
+        ):
+            raise KeyError(
+                f"A converter exists for {key}, but it was " "explicitly disallowed"
+            )
+
# Iterate over all registries, validating the converter on the input node
# If no capability_validator function is found, assume full coverage
for registry, calling_convention in zip(
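A short sketch of the new registry hook, mirroring the compile_module wiring added above (the specific op is just an example):

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)

# After this call, converter lookups for aten.abs raise KeyError, so the
# partitioner treats the op as unsupported even though a converter is registered.
CONVERTERS.set_disallowed_targets({torch.ops.aten.abs.default})

Because the lookup also compares qualified_name_or_str(key), string entries such as "torch.ops.aten.abs.default" are matched as well.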
py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -42,8 +42,10 @@ def is_node_supported(
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
-            node in CONVERTERS or node.op == "get_attr"
-        ) and node_name not in self.torch_executed_ops:
+            (node in CONVERTERS or node.op == "get_attr")
+            and node_name not in self.torch_executed_ops
+            and node.target not in self.torch_executed_ops
+        ):
# If node is a proper, supported computational node, store the operator
if not node.is_impure() and node.op != "get_attr":
if node_name not in self.supported_operators:
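The change above makes the exclusion check consult both the stringified name and the raw node.target. A standalone sketch of that dual membership test (no TensorRT engine required; qualified_name_or_str is the same helper used in the diff):

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry

# Users may list torch-executed ops as OpOverload objects or as qualified-name
# strings; the updated condition accepts either form.
torch_executed_ops = {torch.ops.aten.abs.default}

target = torch.ops.aten.abs.default
node_name = ConverterRegistry.qualified_name_or_str(target)

excluded = node_name in torch_executed_ops or target in torch_executed_ops
print(excluded)  # True: matched via the Target object, regardless of the string form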
25 changes: 11 additions & 14 deletions py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -1,8 +1,9 @@
import logging
-from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple
+from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple

import torch
from torch.fx.graph_module import GraphModule
+from torch.fx.node import Target
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupport, SupportDict
from torch_tensorrt.dynamo._defaults import (
@@ -133,25 +134,25 @@ class TorchTensorRTOperatorSupport(OperatorSupport): # type: ignore[misc]
def __init__(
self,
support_dict: Optional[SupportDict] = None,
-        torch_executed_ops: Optional[Set[str]] = None,
+        torch_executed_ops: Collection[Target] = set(),
):
super().__init__(support_dict)

# Initialize sets of supported/unsupported operators
self.supported_operators: Dict[str, int] = {}
self.unsupported_operators: Dict[str, int] = {}
-        self.torch_executed_ops: Set[str] = (
-            torch_executed_ops if torch_executed_ops is not None else set()
-        )
+        self.torch_executed_ops: Collection[Target] = torch_executed_ops

def is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
-            node in CONVERTERS or node.op == "get_attr"
-        ) and node_name not in self.torch_executed_ops:
+            (node in CONVERTERS or node.op == "get_attr")
+            and node_name not in self.torch_executed_ops
+            and node.target not in self.torch_executed_ops
+        ):
# If node is a proper, supported computational node, store the operator
if not node.is_impure() and node.op != "get_attr":
if node_name not in self.supported_operators:
@@ -201,7 +202,7 @@ def partition(
gm: torch.fx.GraphModule,
verbose: bool = DEBUG,
min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[Set[str]] = None,
+    torch_executed_ops: Collection[Target] = set(),
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]:
"""Partition an FX GraphModule with aten ops into TRT engines
@@ -211,16 +212,12 @@
gm: FX GraphModule to partition
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Whether to require that all operators be run in TRT
Returns:
torch.fx.GraphModule, TorchTensorRTOperatorSupport
"""
-    supported_ops = TorchTensorRTOperatorSupport(
-        torch_executed_ops=torch_executed_ops
-        if torch_executed_ops is not None
-        else set()
-    )
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
partitioner = TRTPartitioner(
gm,
supported_ops,
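Finally, a sketch of exercising the global partitioner directly with the new argument type (the toy module and the use of torch.export to obtain an aten-level graph are assumptions for illustration):

import torch
from torch_tensorrt.dynamo.partitioning._global_partitioner import partition

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(torch.abs(x))

# Lower to an aten-level fx.GraphModule that the partitioner can consume.
gm = torch.export.export(Net(), (torch.randn(2, 4),)).module()

# aten.abs is forced to run in PyTorch, so it should be excluded from TRT blocks.
partitioned_gm, op_support = partition(
    gm,
    min_block_size=1,
    torch_executed_ops={torch.ops.aten.abs.default},
)
print([name for name, _ in partitioned_gm.named_children()])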
5 changes: 3 additions & 2 deletions tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,10 @@
import pytest
import timm
import torch
-import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

+import torch_tensorrt as torchtrt
+
assertions = unittest.TestCase()


@@ -97,7 +98,7 @@ def forward(self, x):
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"torch_executed_ops": "torch.ops.aten.abs.default",
"torch_executed_ops": {"torch.ops.aten.abs.default"},
"min_block_size": 1,
}

5 changes: 3 additions & 2 deletions tests/py/dynamo/models/test_export_serde.py
@@ -3,10 +3,11 @@
import pytest
import timm
import torch
-import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

+import torch_tensorrt as torchtrt
+
assertions = unittest.TestCase()


@@ -206,7 +207,7 @@ def forward(self, x):
],
"ir": ir,
"min_block_size": 1,
"torch_executed_ops": "torch.ops.aten.relu.default",
"torch_executed_ops": {"torch.ops.aten.relu.default"},
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)