Skip to content

Commit 36d249d

Browse files
authored
fix: Fix null inputs case (#3334)
1 parent 68c2d45 commit 36d249d

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,10 @@ def cross_compile_for_windows(
263263
"When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True"
264264
)
265265
# Aliasing inputs to arg_inputs for better understanding
266-
if not arg_inputs and not inputs:
267-
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
266+
if not arg_inputs and not kwarg_inputs and not inputs:
267+
raise AssertionError(
268+
"'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None."
269+
)
268270

269271
elif arg_inputs and inputs:
270272
raise AssertionError(
@@ -582,8 +584,10 @@ def compile(
582584
"When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True"
583585
)
584586
# Aliasing inputs to arg_inputs for better understanding
585-
if not arg_inputs and not inputs:
586-
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
587+
if not arg_inputs and not kwarg_inputs and not inputs:
588+
raise AssertionError(
589+
"'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None."
590+
)
587591

588592
elif arg_inputs and inputs:
589593
raise AssertionError(
@@ -1069,8 +1073,10 @@ def convert_exported_program_to_serialized_trt_engine(
10691073
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
10701074
)
10711075

1072-
if arg_inputs is None and inputs is None:
1073-
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
1076+
if not arg_inputs and not kwarg_inputs and not inputs:
1077+
raise AssertionError(
1078+
"'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None."
1079+
)
10741080

10751081
elif arg_inputs is not None and inputs is not None:
10761082
raise AssertionError(

0 commit comments

Comments
 (0)