fix: Add special cases for clone and to_copy where input of graph is output #2265
Changes from all commits
@@ -45,27 +45,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
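For context, an illustrative sketch (not code from this PR): the new predicate targets graphs where an op such as clone sits directly between a placeholder and the graph output. A minimal FX example using torch.fx.symbolic_trace, with hypothetical module and variable names, flags that node with the same two checks:

    # Hypothetical sketch: a graph where clone is the only operator between the
    # input placeholder and the graph output.
    import torch
    import torch.fx


    class CloneOnly(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.clone(x)


    gm = torch.fx.symbolic_trace(CloneOnly())
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        # Same two conditions the new validator checks: the node consumes a
        # placeholder and feeds the graph output directly.
        reads_placeholder = any(
            isinstance(arg, torch.fx.Node) and arg.op == "placeholder"
            for arg in node.args
        )
        feeds_output = any(user.op == "output" for user in node.users)
        print(node.target, reads_placeholder and feeds_output)  # expect True for the clone node

The guarded ("val" in node.meta) lookups above serve the same defensive purpose: nodes that never received fake-tensor metadata no longer raise a KeyError during validation.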
@@ -82,9 +104,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
+    Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
 
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
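As a rough sketch of the behavior the docstring describes (assuming the standard TensorRT Python API; the helper name cast_via_identity is illustrative, not the function in this file), an identity layer with an overridden output type performs the cast, and the input is returned untouched when the dtypes already match:

    # Illustrative sketch only, not the converter's actual implementation.
    import tensorrt as trt


    def cast_via_identity(
        network: trt.INetworkDefinition, input_val: trt.ITensor, dtype: trt.DataType
    ) -> trt.ITensor:
        if input_val.dtype == dtype:
            # No conversion needed; return the input unchanged.
            return input_val
        identity = network.add_identity(input_val)
        # Force the identity layer to emit the requested dtype, performing the cast.
        identity.set_output_type(0, dtype)
        return identity.get_output(0)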
@@ -191,7 +216,7 @@ def extend_attr_to_tuple(
     if isinstance(val, tuple):
         return val
     else:
-        raise AssertionError(f"Could not extend attribute {val}")
+        raise AssertionError(f"Object {val} could not be extended to tuple")
 
 
 def cast_int_or_float_to_bool(