
Commit b76024d

fix: Missing parameters in compiler settings (#2749)
1 parent 89003e9 commit b76024d

File tree

1 file changed: +5 −1 lines changed

py/torch_tensorrt/ts/_compiler.py

Lines changed: 5 additions & 1 deletion
@@ -4,11 +4,12 @@

 import torch
 import torch_tensorrt._C.ts as _C
-from torch_tensorrt import _enums
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device

+from torch_tensorrt import _enums
+

 def compile(
     module: torch.jit.ScriptModule,
@@ -137,6 +138,9 @@ def compile(
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
         "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double,
         "torch_fallback": {
