Skip to content

Commit de4f358

Browse files
committed
tests(//tests): Implementing tests to check new input class behavior
This commit adds tests for the new Input class, including verifying that default behavior works properly. It also moves tests out of module and into cpp for cpp api tests Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent bdde52e commit de4f358

File tree

13 files changed

+390
-116
lines changed

13 files changed

+390
-116
lines changed

cpp/api/include/trtorch/trtorch.h

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,19 @@ struct TRTORCH_API CompileSpec {
371371
DataType dtype;
372372
/// Expected tensor format for the input
373373
TensorFormat format;
374+
375+
/**
376+
* @brief Construct a new Input spec object for static input size from
377+
* vector, optional arguments allow the user to configure expected input shape
378+
* tensor format. dtype (Expected data type for the input) defaults to PyTorch
379+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
380+
*
381+
* @param shape Input tensor shape
382+
* @param dtype Expected data type for the input (Defaults to Float32)
383+
* @param format Expected tensor format for the input (Defaults to contiguous)
384+
*/
385+
Input(std::vector<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
386+
374387
/**
375388
* @brief Construct a new Input spec object for static input size from
376389
* vector, optional arguments allow the user to configure expected input shape
@@ -380,7 +393,20 @@ struct TRTORCH_API CompileSpec {
380393
* @param dtype Expected data type for the input (Defaults to Float32)
381394
* @param format Expected tensor format for the input (Defaults to contiguous)
382395
*/
383-
Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
396+
Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
397+
398+
/**
399+
* @brief Construct a new Input spec object for static input size from
400+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
401+
* allow the user to configure expected input shape tensor format
402+
* dtype (Expected data type for the input) defaults to PyTorch
403+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
404+
*
405+
* @param shape Input tensor shape
406+
* @param format Expected tensor format for the input (Defaults to contiguous)
407+
*/
408+
Input(c10::ArrayRef<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
409+
384410
/**
385411
* @brief Construct a new Input spec object for static input size from
386412
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -390,7 +416,21 @@ struct TRTORCH_API CompileSpec {
390416
* @param dtype Expected data type for the input (Defaults to Float32)
391417
* @param format Expected tensor format for the input (Defaults to contiguous)
392418
*/
393-
Input(c10::ArrayRef<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
419+
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
420+
421+
/**
422+
* @brief Construct a new Input Range object for dynamic input size from
423+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
424+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
425+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
426+
*
427+
* @param min_shape Minimum shape for input tensor
428+
* @param opt_shape Target optimization shape for input tensor
429+
* @param max_shape Maximum acceptable shape for input tensor
430+
* @param format Expected tensor format for the input (Defaults to contiguous)
431+
*/
432+
Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
433+
394434
/**
395435
* @brief Construct a new Input spec object for a dynamic input size from vectors
396436
* for minimum shape, optimal shape, and max shape supported sizes optional arguments
@@ -402,7 +442,21 @@ struct TRTORCH_API CompileSpec {
402442
* @param dtype Expected data type for the input (Defaults to Float32)
403443
* @param format Expected tensor format for the input (Defaults to contiguous)
404444
*/
405-
Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
445+
Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
446+
447+
/**
448+
* @brief Construct a new Input Range object for dynamic input size from
449+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
450+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
451+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
452+
*
453+
* @param min_shape Minimum shape for input tensor
454+
* @param opt_shape Target optimization shape for input tensor
455+
* @param max_shape Maximum acceptable shape for input tensor
456+
* @param format Expected tensor format for the input (Defaults to contiguous)
457+
*/
458+
Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
459+
406460
/**
407461
* @brief Construct a new Input Range object dynamic input size from
408462
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -414,10 +468,12 @@ struct TRTORCH_API CompileSpec {
414468
* @param dtype Expected data type for the input (Defaults to Float32)
415469
* @param format Expected tensor format for the input (Defaults to contiguous)
416470
*/
417-
Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
471+
Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
418472

473+
bool get_explicit_set_dtype() {return explicit_set_dtype;}
419474
private:
420475
bool input_is_dynamic;
476+
bool explicit_set_dtype;
421477
};
422478

423479
/**
@@ -512,28 +568,45 @@ struct TRTORCH_API CompileSpec {
512568
*
513569
* @param input_ranges
514570
*/
515-
[[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
571+
[[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
516572
CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
517573
/**
518574
* @brief Construct a new Extra Info object
519575
* Convenience constructor to set fixed input size from vectors describing
520576
* size of input tensors. Each entry in the vector represents a input and
521577
* should be provided in call order.
522578
*
579+
* This constructor should be used as a convenience in the case that all inputs are static sized and
580+
* you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
581+
*
523582
* @param fixed_sizes
524583
*/
525-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
526584
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
585+
527586
/**
528587
* @brief Construct a new Extra Info object
529588
* Convenience constructor to set fixed input size from c10::ArrayRef's (the
530589
* output of tensor.sizes()) describing size of input tensors. Each entry in
531590
* the vector represents a input and should be provided in call order.
591+
*
592+
* This constructor should be used as a convenience in the case that all inputs are static sized and
593+
* you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
594+
*
532595
* @param fixed_sizes
533596
*/
534-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
535597
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
536598

599+
/**
600+
* @brief Construct a new Extra Info object from input ranges.
601+
* Each entry in the vector represents a input and should be provided in call
602+
* order.
603+
*
604+
* Use this constructor to define inputs with dynamic shape, specific input types or tensor formats
605+
*
606+
* @param inputs
607+
*/
608+
CompileSpec(std::vector<Input> inputs) : inputs(std::move(inputs)) {}
609+
537610
// Defaults should reflect TensorRT defaults for BuilderConfig
538611

539612
/**

cpp/api/src/compile_spec.cpp

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,46 @@ CompileSpec::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt,
105105

106106
CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
107107
for (auto in : fixed_sizes) {
108-
input_ranges.push_back(InputRange(in));
108+
inputs.push_back(Input(in));
109109
}
110110
}
111111

112112
CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
113113
for (auto in : fixed_sizes) {
114-
input_ranges.push_back(InputRange(in));
114+
inputs.push_back(Input(in));
115115
}
116116
}
117117

118118
/* ====== DEFINE INPUTS CLASS MEMBERS ======*/
119+
CompileSpec::Input::Input(std::vector<int64_t> shape, TensorFormat format) {
120+
this->opt_shape = shape;
121+
this->min_shape = shape;
122+
this->max_shape = shape;
123+
this->shape = shape;
124+
this->dtype = dtype;
125+
this->explicit_set_dtype = false;
126+
this->format = format;
127+
this->input_is_dynamic = false;
128+
}
129+
119130
CompileSpec::Input::Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format) {
120131
this->opt_shape = shape;
121132
this->min_shape = shape;
122133
this->max_shape = shape;
123134
this->shape = shape;
124135
this->dtype = dtype;
136+
this->explicit_set_dtype = true;
137+
this->format = format;
138+
this->input_is_dynamic = false;
139+
}
140+
141+
CompileSpec::Input::Input(c10::IntArrayRef shape, TensorFormat format) {
142+
this->opt_shape = core::util::toVec(shape);
143+
this->min_shape = core::util::toVec(shape);
144+
this->max_shape = core::util::toVec(shape);
145+
this->shape = core::util::toVec(shape);
146+
this->dtype = DataType::kFloat;
147+
this->explicit_set_dtype = false;
125148
this->format = format;
126149
this->input_is_dynamic = false;
127150
}
@@ -132,16 +155,40 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
132155
this->max_shape = core::util::toVec(shape);
133156
this->shape = core::util::toVec(shape);
134157
this->dtype = dtype;
158+
this->explicit_set_dtype = true;
135159
this->format = format;
136160
this->input_is_dynamic = false;
137161
}
138162

163+
CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format) {
164+
this->opt_shape = opt_shape;
165+
this->min_shape = min_shape;
166+
this->max_shape = max_shape;
167+
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
168+
this->dtype = dtype;
169+
this->explicit_set_dtype = false;
170+
this->format = format;
171+
this->input_is_dynamic = true;
172+
}
173+
139174
CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format) {
140175
this->opt_shape = opt_shape;
141176
this->min_shape = min_shape;
142177
this->max_shape = max_shape;
143178
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
144179
this->dtype = dtype;
180+
this->explicit_set_dtype = true;
181+
this->format = format;
182+
this->input_is_dynamic = true;
183+
}
184+
185+
CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, TensorFormat format) {
186+
this->opt_shape = core::util::toVec(opt_shape);
187+
this->min_shape = core::util::toVec(min_shape);
188+
this->max_shape = core::util::toVec(max_shape);
189+
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
190+
this->dtype = dtype;
191+
this->explicit_set_dtype = false;
145192
this->format = format;
146193
this->input_is_dynamic = true;
147194
}
@@ -152,6 +199,7 @@ CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape
152199
this->max_shape = core::util::toVec(max_shape);
153200
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
154201
this->dtype = dtype;
202+
this->explicit_set_dtype = true;
155203
this->format = format;
156204
this->input_is_dynamic = true;
157205
}
@@ -191,14 +239,31 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
191239
internal = core::CompileSpec(to_vec_internal_inputs(external.inputs));
192240
}
193241

194-
if (external.enabled_precisions.size() <= 1 && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
242+
if (external.enabled_precisions.size() <= 1 && toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
195243
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
196244
} else {
197245
for(auto p : external.enabled_precisions) {
198246
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
199247
}
200248
}
201249

250+
/* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype for
251+
inputs they will follow PyTorch conventions */
252+
for (size_t i = 0; i < external.inputs.size(); i++) {
253+
std::cout << "EXPLICIT " << external.inputs[i].get_explicit_set_dtype() << std::endl;
254+
if (!external.inputs[i].get_explicit_set_dtype()) {
255+
auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
256+
auto& internal_ins = internal.convert_info.inputs;
257+
if (precisions.find(nvinfer1::DataType::kINT8) != precisions.end()) {
258+
internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
259+
} else if (precisions.find(nvinfer1::DataType::kHALF) != precisions.end()) {
260+
internal_ins[i].dtype = nvinfer1::DataType::kHALF;
261+
} else {
262+
internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
263+
}
264+
std::cout << "internal type: " << internal_ins[i].dtype;
265+
}
266+
}
202267
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
203268
internal.convert_info.engine_settings.refit = external.refit;
204269
internal.convert_info.engine_settings.debug = external.debug;

tests/BUILD

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
test_suite(
22
name = "tests",
33
tests = [
4+
":cpp_api_tests",
45
"//tests/core:core_tests",
5-
"//tests/modules:module_tests",
66
],
77
)
88

@@ -30,6 +30,13 @@ test_suite(
3030
],
3131
)
3232

33+
test_suite(
34+
name = "cpp_api_tests",
35+
tests = [
36+
"//tests/cpp:api_tests"
37+
]
38+
)
39+
3340
test_suite(
3441
name = "python_api_tests",
3542
tests = [

tests/core/runtime/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
},
8+
)

0 commit comments

Comments
 (0)