Skip to content

Commit 6dc3bfa

Browse files
committed
chore: Initial input_type api impl
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent f0fb4c5 commit 6dc3bfa

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

core/ir/ir.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <vector>
44
#include "NvInfer.h"
5+
// #include "trtorch.h"
56

67
namespace trtorch {
78
namespace core {
@@ -18,6 +19,17 @@ struct InputRange {
1819
InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
1920
};
2021

22+
// struct Input{
23+
// Input(std::vector<int64_t> shape);
24+
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
25+
// Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat32);
26+
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat32);
27+
// nvinfer1::Dims min;
28+
// nvinfer1::Dims max;
29+
// nvinfer1::Dims opt;
30+
// nvinfer1::DataType dtype;
31+
// }
32+
2133
} // namespace ir
2234
} // namespace core
23-
} // namespace trtorch
35+
} // namespace trtorch

cpp/api/include/trtorch/trtorch.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,59 @@ struct TRTORCH_API CompileSpec {
191191
Value value;
192192
};
193193

194+
/**
195+
* @brief A struct to hold Input of a network.
196+
* This struct has all the info (shape, dtype, name, memory_format) of an input tensor.
197+
* The shape field in this struct can either hold a single vector representing an input shape,
198+
* signifying a static input shape or a set of three input shapes representing
199+
* the min, optiminal and max input shapes allowed for the engine.
200+
* dtype : This can take values among values supported by trtorch::DataType
201+
*/
202+
struct TRTORCH_API Input {
203+
/// Minimum acceptable input size into the engine
204+
std::vector<int64_t> min;
205+
/// Optimal input size into the engine (gets best performace)
206+
std::vector<int64_t> opt;
207+
/// Maximum acceptable input size into the engine
208+
std::vector<int64_t> max;
209+
/// Data type of the input
210+
DataType dtype;
211+
212+
/**
213+
* @brief Construct a new Input Range object for static input size from
214+
* vector
215+
*
216+
* @param opt
217+
*/
218+
Input(std::vector<int64_t> opt, DataType dtype=DataType::kFloat);
219+
/**
220+
* @brief Construct a new Input Range object static input size from
221+
* c10::ArrayRef (the type produced by tensor.sizes())
222+
*
223+
* @param opt
224+
*/
225+
Input(c10::ArrayRef<int64_t> opt, DataType dtype=DataType::kFloat);
226+
/**
227+
* @brief Construct a new Input Range object dynamic input size from vectors
228+
* for min, opt, and max supported sizes
229+
*
230+
* @param min
231+
* @param opt
232+
* @param max
233+
*/
234+
Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max, DataType dtype=DataType::kFloat);
235+
/**
236+
* @brief Construct a new Input Range object dynamic input size from
237+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
238+
* supported sizes
239+
*
240+
* @param min
241+
* @param opt
242+
* @param max
243+
*/
244+
Input(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max, DataType dtype=DataType::kFloat);
245+
};
246+
194247
/**
195248
* Emum for selecting engine capability
196249
*/

cpp/api/src/compile_spec.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,33 @@ CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
7171
}
7272
}
7373

74+
/* ====== DEFINE INPUTS CLASS MEMBERS ======*/
75+
CompileSpec::Input::Input(std::vector<int64_t> opt) {
76+
this->opt = opt;
77+
this->min = opt;
78+
this->max = opt;
79+
}
80+
81+
CompileSpec::Input::Input(c10::IntArrayRef opt) {
82+
this->opt = core::util::toVec(opt);
83+
this->min = core::util::toVec(opt);
84+
this->max = core::util::toVec(opt);
85+
}
86+
87+
CompileSpec::Input::Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
88+
this->opt = opt;
89+
this->min = min;
90+
this->max = max;
91+
}
92+
93+
CompileSpec::Input::Input(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
94+
this->opt = core::util::toVec(opt);
95+
this->min = core::util::toVec(min);
96+
this->max = core::util::toVec(max);
97+
}
98+
99+
/* ==========================================*/
100+
74101
core::ir::InputRange to_internal_input_range(CompileSpec::InputRange i) {
75102
return core::ir::InputRange(i.min, i.opt, i.max);
76103
}

0 commit comments

Comments
 (0)