@@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190
190
}
191
191
}
192
192
193
+ void EvaluateLoopBlock (ConversionCtx* ctx, const torch::jit::Node* n);
194
+
193
195
void MapIValues (ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
194
196
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
195
197
std::transform (in_list.begin () + in_offset, in_list.end (), out_list.begin () + out_offset,
@@ -199,11 +201,54 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
199
201
});
200
202
201
203
for (auto p : input_output_pairs) {
202
- auto input = ctx->evaluated_value_map [p.first ];
203
- ctx->evaluated_value_map [p.second ] = torch::jit::IValue (input);
204
+ if (ctx->evaluated_value_map .find (p.first ) != ctx->evaluated_value_map .end ()) {
205
+ auto input = ctx->evaluated_value_map [p.first ];
206
+ ctx->evaluated_value_map [p.second ] = torch::jit::IValue (input);
207
+ } else if (ctx->value_tensor_map .find (p.first ) != ctx->value_tensor_map .end ()) {
208
+ auto input = ctx->value_tensor_map [p.first ];
209
+ ctx->value_tensor_map [p.second ] = input;
210
+ } else {
211
+ TRTORCH_THROW_ERROR (" Cannot find Value " << p.first ->debugName () << " either evaluated values or tensor maps (MapIValues)" );
212
+ }
204
213
}
205
214
}
206
215
216
+ void EvaluateConditionalBlock (ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false ) {
217
+ bool output_type_includes_tensor = false ;
218
+ for (auto o : n->outputs ()) {
219
+ if (o->type ()->isSubtypeOf (c10::TensorType::get ())) {
220
+ output_type_includes_tensor = true ;
221
+ }
222
+ }
223
+ TRTORCH_CHECK (!(contained_in_loop && output_type_includes_tensor), " TRTorch currently cannot compile conditionals within loops" );
224
+
225
+ auto condition = ctx->evaluated_value_map [n->input (0 )].toBool ();
226
+ LOG_DEBUG (ctx->logger , " (Conditional Evaluation) Evaluating block " << (int ) condition);
227
+ auto b = condition ? n->blocks ()[0 ] : n->blocks ()[1 ];
228
+
229
+ for (const auto bn : b->nodes ()) {
230
+ if (bn->kind () == torch::jit::prim::Loop) {
231
+ EvaluateLoopBlock (ctx, bn);
232
+ } else if (bn->kind () == torch::jit::prim::If) {
233
+ EvaluateConditionalBlock (ctx, bn, contained_in_loop);
234
+ } else if (evaluators::shouldEvalAtConversionTime (bn)) {
235
+ auto eval = EvaluateNode (ctx, bn);
236
+ if (!eval.value ().isTensor ()) {
237
+ LOG_DEBUG (ctx->logger , " (Conditional Evaluation) Found the value to be: " << eval.value ());
238
+ } else {
239
+ LOG_DEBUG (ctx->logger , " (Conditional Evaluation) Found the value to be a tensor (shape " << eval.value ().toTensor ().sizes () << ' )' );
240
+ }
241
+ ctx->AssociateValueAndIValue (bn->output (0 ), eval.value ());
242
+ } else if (converters::node_is_convertable (bn)) {
243
+ AddLayer (ctx, bn);
244
+ } else {
245
+ TRTORCH_THROW_ERROR (" TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn);
246
+ }
247
+ }
248
+
249
+ MapIValues (ctx, b->outputs (), n->outputs (), 0 , 0 );
250
+ }
251
+
207
252
// TODO: With functionalization pass we may be able to make this into a regular evaluator later
208
253
void EvaluateLoopBlock (ConversionCtx* ctx, const torch::jit::Node* n) {
209
254
auto max_trip_count = ctx->evaluated_value_map [n->input (0 )];
@@ -213,16 +258,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
213
258
214
259
MapIValues (ctx, n->inputs (), n->outputs (), 2 , 0 );
215
260
216
- LOG_DEBUG (" (Loop Evaluation) Evaluating loop " << *n);
217
- LOG_DEBUG (" (Loop Evaluation) Max Trip Count: " << max_trip_count.toInt ());
218
- LOG_DEBUG (" (Loop Evaluation) Start Condition: " << start_cond.toBool ());
219
- LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
261
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Evaluating loop " << *n);
262
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Max Trip Count: " << max_trip_count.toInt ());
263
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Start Condition: " << start_cond.toBool ());
264
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
220
265
221
266
while (start_cond.toBool () && trip_count.toInt () < max_trip_count.toInt ()) {
222
267
MapIValues (ctx, n->outputs (), n->blocks ()[0 ]->inputs (), 0 , 1 );
223
268
for (auto bn : n->blocks ()[0 ]->nodes ()) {
224
- auto eval = EvaluateNode (ctx, bn);
225
- if (eval) {
269
+ if (bn->kind () == torch::jit::prim::Loop) {
270
+ EvaluateLoopBlock (ctx, n);
271
+ } else if (bn->kind () == torch::jit::prim::If) {
272
+ EvaluateConditionalBlock (ctx, bn, true );
273
+ } else {
274
+ TRTORCH_CHECK (evaluators::shouldEvalAtConversionTime (bn), " TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated." );
275
+ auto eval = EvaluateNode (ctx, bn);
226
276
if (!eval.value ().isTensor ()) {
227
277
LOG_DEBUG (ctx->logger , " (Loop Evaluation) Found the value to be: " << eval.value ());
228
278
} else {
@@ -236,8 +286,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
236
286
start_cond = ctx->evaluated_value_map [n->blocks ()[0 ]->outputs ()[0 ]];
237
287
auto new_trip_count = torch::jit::IValue (trip_count.toInt () + 1 );
238
288
trip_count.swap (new_trip_count);
239
- LOG_DEBUG (" (Loop Evaluation) Condition: " << start_cond.toBool ());
240
- LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
289
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Condition: " << start_cond.toBool ());
290
+ LOG_DEBUG (ctx-> logger , " (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
241
291
}
242
292
}
243
293
@@ -255,6 +305,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
255
305
bool blacklisted = isNodeConversionBlacklisted (n);
256
306
if (n->kind () == torch::jit::prim::Loop) {
257
307
EvaluateLoopBlock (ctx, n);
308
+ } else if (n->kind () == torch::jit::prim::If) {
309
+ EvaluateConditionalBlock (ctx, n);
258
310
} else if (to_eval) {
259
311
auto eval = EvaluateNode (ctx, n);
260
312
if (eval) {
@@ -303,10 +355,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
303
355
std::set<std::string> GetUnsupportedOpsInBlock (const torch::jit::Block* b ) {
304
356
std::set<std::string> unsupported_ops;
305
357
for (const auto n : b->nodes ()) {
306
- if (n->kind () != torch::jit::prim::Loop && !OpSupported (n)) {
358
+ if (n->kind () != torch::jit::prim::Loop && n-> kind () != torch::jit::prim::If && !OpSupported (n)) {
307
359
auto schema = n->maybeSchema ();
308
360
TRTORCH_CHECK (schema, " Unable to get schema for Node " << util::node_info (n) \
309
- << " (conversion.VerifyCoverterSupportForBlock" );
361
+ << " (conversion.VerifyCoverterSupportForBlock) " );
310
362
std::stringstream ss;
311
363
ss << *schema;
312
364
unsupported_ops.insert (ss.str ());
0 commit comments