Skip to content

Commit 233d0bf

Browse files
authored
don't initialize cuda at import time (#3244)
1 parent 5129688 commit 233d0bf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
307307
def TensorRTCompileSpec(
308308
inputs: Optional[List[torch.Tensor | Input]] = None,
309309
input_signature: Optional[Any] = None,
310-
device: torch.device | Device = Device._current_device(),
310+
device: Optional[torch.device | Device] = None,
311311
disable_tf32: bool = False,
312312
sparse_weights: bool = False,
313313
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
@@ -365,7 +365,7 @@ def TensorRTCompileSpec(
365365
compile_spec = {
366366
"inputs": inputs if inputs is not None else [],
367367
# "input_signature": input_signature,
368-
"device": device,
368+
"device": Device._current_device() if device is None else device,
369369
"disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
370370
"sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
371371
"enabled_precisions": (

0 commit comments

Comments
 (0)