Skip to content

Commit 72189f9

Browse files
authored
[TensorRT EP] Enable more trt options (#237)
1 parent 6b896f0 commit 72189f9

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,48 @@ TensorRT can be used in conjunction with an ONNX model to further optimize the p
9393
* `trt_engine_cache_enable`: Enable engine caching.
9494
* `trt_engine_cache_path`: Specify engine cache path.
9595

96+
To explore more parameters, use the mapping table below and see the [ONNX Runtime documentation](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#execution-provider-options) for details.
97+
98+
> Please link to the latest ONNX Runtime binaries in CMake or build from the [main branch of ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/main) to enable the latest options.
99+
100+
### Parameter mapping between ONNX Runtime and Triton ONNXRuntime Backend
101+
102+
| Key in Triton model configuration | Value in Triton model config | Corresponding TensorRT EP option in ONNX Runtime | Type |
103+
| --------------------------------- | --------------------------------------------------- | :----------------------------------------------- | :----- |
104+
| max_workspace_size_bytes | e.g: "4294967296" | trt_max_workspace_size | int |
105+
| trt_max_partition_iterations | e.g: "1000" | trt_max_partition_iterations | int |
106+
| trt_min_subgraph_size | e.g: "1" | trt_min_subgraph_size | int |
107+
| precision_mode | "FP16" | trt_fp16_enable | bool |
108+
| precision_mode | "INT8" | trt_int8_enable | bool |
109+
| int8_calibration_table_name | | trt_int8_calibration_table_name | string |
110+
| int8_use_native_calibration_table | e.g: "1" or "true", "0" or "false" | trt_int8_use_native_calibration_table | bool |
111+
| trt_dla_enable | | trt_dla_enable | bool |
112+
| trt_dla_core | e.g: "0" | trt_dla_core | int |
113+
| trt_engine_cache_enable | e.g: "1" or "true", "0" or "false" | trt_engine_cache_enable | bool |
114+
| trt_engine_cache_path | | trt_engine_cache_path | string |
115+
| trt_engine_cache_prefix | | trt_engine_cache_prefix | string |
116+
| trt_dump_subgraphs | e.g: "1" or "true", "0" or "false" | trt_dump_subgraphs | bool |
117+
| trt_force_sequential_engine_build | e.g: "1" or "true", "0" or "false" | trt_force_sequential_engine_build | bool |
118+
| trt_context_memory_sharing_enable | e.g: "1" or "true", "0" or "false" | trt_context_memory_sharing_enable | bool |
119+
| trt_layer_norm_fp32_fallback | e.g: "1" or "true", "0" or "false" | trt_layer_norm_fp32_fallback | bool |
120+
| trt_timing_cache_enable | e.g: "1" or "true", "0" or "false" | trt_timing_cache_enable | bool |
121+
| trt_timing_cache_path | | trt_timing_cache_path | string |
122+
| trt_force_timing_cache | e.g: "1" or "true", "0" or "false" | trt_force_timing_cache | bool |
123+
| trt_detailed_build_log | e.g: "1" or "true", "0" or "false" | trt_detailed_build_log | bool |
124+
| trt_build_heuristics_enable | e.g: "1" or "true", "0" or "false" | trt_build_heuristics_enable | bool |
125+
| trt_sparsity_enable | e.g: "1" or "true", "0" or "false" | trt_sparsity_enable | bool |
126+
| trt_builder_optimization_level | e.g: "3" | trt_builder_optimization_level | int |
127+
| trt_auxiliary_streams | e.g: "-1" | trt_auxiliary_streams | int |
128+
| trt_tactic_sources | e.g: "-CUDNN,+CUBLAS" | trt_tactic_sources | string |
129+
| trt_extra_plugin_lib_paths | | trt_extra_plugin_lib_paths | string |
130+
| trt_profile_min_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_min_shapes | string |
131+
| trt_profile_max_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_max_shapes | string |
132+
| trt_profile_opt_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_opt_shapes | string |
133+
| trt_cuda_graph_enable | e.g: "1" or "true", "0" or "false" | trt_cuda_graph_enable | bool |
134+
| trt_dump_ep_context_model | e.g: "1" or "true", "0" or "false" | trt_dump_ep_context_model | bool |
135+
| trt_ep_context_file_path | | trt_ep_context_file_path | string |
136+
| trt_ep_context_embed_mode | e.g: "1" | trt_ep_context_embed_mode | int |
137+
96138
The section of model config file specifying these parameters will look like:
97139

98140
```
@@ -104,6 +146,7 @@ optimization { execution_accelerators {
104146
name : "tensorrt"
105147
parameters { key: "precision_mode" value: "FP16" }
106148
parameters { key: "max_workspace_size_bytes" value: "1073741824" }
149+
parameters { key: "trt_engine_cache_enable" value: "1" }}
107150
]
108151
}}
109152
.

src/onnxruntime.cc

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,22 @@ ModelState::LoadModel(
473473
value_string, &max_workspace_size_bytes));
474474
key = "trt_max_workspace_size";
475475
value = value_string;
476+
} else if (param_key == "trt_max_partition_iterations") {
477+
RETURN_IF_ERROR(params.MemberAsString(
478+
param_key.c_str(), &value_string));
479+
int trt_max_partition_iterations;
480+
RETURN_IF_ERROR(ParseIntValue(
481+
value_string, &trt_max_partition_iterations));
482+
key = "trt_max_partition_iterations";
483+
value = value_string;
484+
} else if (param_key == "trt_min_subgraph_size") {
485+
RETURN_IF_ERROR(params.MemberAsString(
486+
param_key.c_str(), &value_string));
487+
int trt_min_subgraph_size;
488+
RETURN_IF_ERROR(
489+
ParseIntValue(value_string, &trt_min_subgraph_size));
490+
key = "trt_min_subgraph_size";
491+
value = value_string;
476492
} else if (param_key == "int8_calibration_table_name") {
477493
RETURN_IF_ERROR(
478494
params.MemberAsString(param_key.c_str(), &value));
@@ -485,6 +501,21 @@ ModelState::LoadModel(
485501
value_string, &use_native_calibration_table));
486502
key = "trt_int8_use_native_calibration_table";
487503
value = value_string;
504+
} else if (param_key == "trt_dla_enable") {
505+
RETURN_IF_ERROR(params.MemberAsString(
506+
param_key.c_str(), &value_string));
507+
bool trt_dla_enable;
508+
RETURN_IF_ERROR(
509+
ParseBoolValue(value_string, &trt_dla_enable));
510+
key = "trt_dla_enable";
511+
value = value_string;
512+
} else if (param_key == "trt_dla_core") {
513+
RETURN_IF_ERROR(params.MemberAsString(
514+
param_key.c_str(), &value_string));
515+
int trt_dla_core;
516+
RETURN_IF_ERROR(ParseIntValue(value_string, &trt_dla_core));
517+
key = "trt_dla_core";
518+
value = value_string;
488519
} else if (param_key == "trt_engine_cache_enable") {
489520
RETURN_IF_ERROR(params.MemberAsString(
490521
param_key.c_str(), &value_string));
@@ -497,6 +528,150 @@ ModelState::LoadModel(
497528
RETURN_IF_ERROR(
498529
params.MemberAsString(param_key.c_str(), &value));
499530
key = "trt_engine_cache_path";
531+
} else if (param_key == "trt_engine_cache_prefix") {
532+
RETURN_IF_ERROR(
533+
params.MemberAsString(param_key.c_str(), &value));
534+
key = "trt_engine_cache_prefix";
535+
} else if (param_key == "trt_dump_subgraphs") {
536+
RETURN_IF_ERROR(params.MemberAsString(
537+
param_key.c_str(), &value_string));
538+
bool dump_subgraphs;
539+
RETURN_IF_ERROR(
540+
ParseBoolValue(value_string, &dump_subgraphs));
541+
key = "trt_dump_subgraphs";
542+
value = value_string;
543+
} else if (param_key == "trt_force_sequential_engine_build") {
544+
RETURN_IF_ERROR(params.MemberAsString(
545+
param_key.c_str(), &value_string));
546+
bool trt_force_sequential_engine_build;
547+
RETURN_IF_ERROR(ParseBoolValue(
548+
value_string, &trt_force_sequential_engine_build));
549+
key = "trt_force_sequential_engine_build";
550+
value = value_string;
551+
} else if (param_key == "trt_context_memory_sharing_enable") {
552+
RETURN_IF_ERROR(params.MemberAsString(
553+
param_key.c_str(), &value_string));
554+
bool trt_context_memory_sharing_enable;
555+
RETURN_IF_ERROR(ParseBoolValue(
556+
value_string, &trt_context_memory_sharing_enable));
557+
key = "trt_context_memory_sharing_enable";
558+
value = value_string;
559+
} else if (param_key == "trt_layer_norm_fp32_fallback") {
560+
RETURN_IF_ERROR(params.MemberAsString(
561+
param_key.c_str(), &value_string));
562+
bool trt_layer_norm_fp32_fallback;
563+
RETURN_IF_ERROR(ParseBoolValue(
564+
value_string, &trt_layer_norm_fp32_fallback));
565+
key = "trt_layer_norm_fp32_fallback";
566+
value = value_string;
567+
} else if (param_key == "trt_timing_cache_enable") {
568+
RETURN_IF_ERROR(params.MemberAsString(
569+
param_key.c_str(), &value_string));
570+
bool trt_timing_cache_enable;
571+
RETURN_IF_ERROR(
572+
ParseBoolValue(value_string, &trt_timing_cache_enable));
573+
key = "trt_timing_cache_enable";
574+
value = value_string;
575+
} else if (param_key == "trt_timing_cache_path") {
576+
RETURN_IF_ERROR(
577+
params.MemberAsString(param_key.c_str(), &value));
578+
key = "trt_timing_cache_path";
579+
} else if (param_key == "trt_force_timing_cache") {
580+
RETURN_IF_ERROR(params.MemberAsString(
581+
param_key.c_str(), &value_string));
582+
bool trt_force_timing_cache;
583+
RETURN_IF_ERROR(
584+
ParseBoolValue(value_string, &trt_force_timing_cache));
585+
key = "trt_force_timing_cache";
586+
value = value_string;
587+
} else if (param_key == "trt_detailed_build_log") {
588+
RETURN_IF_ERROR(params.MemberAsString(
589+
param_key.c_str(), &value_string));
590+
bool trt_detailed_build_log;
591+
RETURN_IF_ERROR(
592+
ParseBoolValue(value_string, &trt_detailed_build_log));
593+
key = "trt_detailed_build_log";
594+
value = value_string;
595+
} else if (param_key == "trt_build_heuristics_enable") {
596+
RETURN_IF_ERROR(params.MemberAsString(
597+
param_key.c_str(), &value_string));
598+
bool trt_build_heuristics_enable;
599+
RETURN_IF_ERROR(ParseBoolValue(
600+
value_string, &trt_build_heuristics_enable));
601+
key = "trt_build_heuristics_enable";
602+
value = value_string;
603+
} else if (param_key == "trt_sparsity_enable") {
604+
RETURN_IF_ERROR(params.MemberAsString(
605+
param_key.c_str(), &value_string));
606+
bool trt_sparsity_enable;
607+
RETURN_IF_ERROR(
608+
ParseBoolValue(value_string, &trt_sparsity_enable));
609+
key = "trt_sparsity_enable";
610+
value = value_string;
611+
} else if (param_key == "trt_builder_optimization_level") {
612+
RETURN_IF_ERROR(params.MemberAsString(
613+
param_key.c_str(), &value_string));
614+
int trt_builder_optimization_level;
615+
RETURN_IF_ERROR(ParseIntValue(
616+
value_string, &trt_builder_optimization_level));
617+
key = "trt_builder_optimization_level";
618+
value = value_string;
619+
} else if (param_key == "trt_auxiliary_streams") {
620+
RETURN_IF_ERROR(params.MemberAsString(
621+
param_key.c_str(), &value_string));
622+
int trt_auxiliary_streams;
623+
RETURN_IF_ERROR(
624+
ParseIntValue(value_string, &trt_auxiliary_streams));
625+
key = "trt_auxiliary_streams";
626+
value = value_string;
627+
} else if (param_key == "trt_tactic_sources") {
628+
RETURN_IF_ERROR(
629+
params.MemberAsString(param_key.c_str(), &value));
630+
key = "trt_tactic_sources";
631+
} else if (param_key == "trt_extra_plugin_lib_paths") {
632+
RETURN_IF_ERROR(
633+
params.MemberAsString(param_key.c_str(), &value));
634+
key = "trt_extra_plugin_lib_paths";
635+
} else if (param_key == "trt_profile_min_shapes") {
636+
RETURN_IF_ERROR(
637+
params.MemberAsString(param_key.c_str(), &value));
638+
key = "trt_profile_min_shapes";
639+
} else if (param_key == "trt_profile_max_shapes") {
640+
RETURN_IF_ERROR(
641+
params.MemberAsString(param_key.c_str(), &value));
642+
key = "trt_profile_max_shapes";
643+
} else if (param_key == "trt_profile_opt_shapes") {
644+
RETURN_IF_ERROR(
645+
params.MemberAsString(param_key.c_str(), &value));
646+
key = "trt_profile_opt_shapes";
647+
} else if (param_key == "trt_cuda_graph_enable") {
648+
RETURN_IF_ERROR(params.MemberAsString(
649+
param_key.c_str(), &value_string));
650+
bool trt_cuda_graph_enable;
651+
RETURN_IF_ERROR(
652+
ParseBoolValue(value_string, &trt_cuda_graph_enable));
653+
key = "trt_cuda_graph_enable";
654+
value = value_string;
655+
} else if (param_key == "trt_dump_ep_context_model") {
656+
RETURN_IF_ERROR(params.MemberAsString(
657+
param_key.c_str(), &value_string));
658+
bool trt_dump_ep_context_model;
659+
RETURN_IF_ERROR(ParseBoolValue(
660+
value_string, &trt_dump_ep_context_model));
661+
key = "trt_dump_ep_context_model";
662+
value = value_string;
663+
} else if (param_key == "trt_ep_context_file_path") {
664+
RETURN_IF_ERROR(
665+
params.MemberAsString(param_key.c_str(), &value));
666+
key = "trt_ep_context_file_path";
667+
} else if (param_key == "trt_ep_context_embed_mode") {
668+
RETURN_IF_ERROR(params.MemberAsString(
669+
param_key.c_str(), &value_string));
670+
int trt_ep_context_embed_mode;
671+
RETURN_IF_ERROR(ParseIntValue(
672+
value_string, &trt_ep_context_embed_mode));
673+
key = "trt_ep_context_embed_mode";
674+
value = value_string;
500675
} else {
501676
return TRITONSERVER_ErrorNew(
502677
TRITONSERVER_ERROR_INVALID_ARG,

0 commit comments

Comments
 (0)