@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
         // update the input ranges for each segment
         convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
+        // TODO: map the input IValues to flattened ones here
         auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
         auto temp_g = std::make_shared<torch::jit::Graph>();
         auto device_spec = convert_cfg.engine_settings.device;
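The TODO above marks where a segment's collection inputs (tuples or lists of tensors) would need to be unrolled before the block is handed to the converter, since a TensorRT engine only consumes individual tensors. Below is a minimal sketch of that flattening under the assumption of one level of nesting; `FlattenInputs` is a hypothetical helper, not part of the Torch-TensorRT API:

```cpp
#include <vector>

#include <ATen/core/ivalue.h>

// Hypothetical helper sketching the flattening the TODO refers to: unroll
// tuple/list IValues into a flat, ordered list of tensors that a TensorRT
// engine can consume. Assumes collections are nested at most one level deep.
std::vector<at::Tensor> FlattenInputs(const std::vector<c10::IValue>& ivalues) {
  std::vector<at::Tensor> flattened;
  for (const auto& iv : ivalues) {
    if (iv.isTensor()) {
      flattened.push_back(iv.toTensor());
    } else if (iv.isTuple()) {
      for (const auto& elem : iv.toTuple()->elements()) {
        flattened.push_back(elem.toTensor());
      }
    } else if (iv.isTensorList()) {
      for (at::Tensor t : iv.toTensorList()) {
        flattened.push_back(std::move(t));
      }
    }
  }
  return flattened;
}
```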
@@ -306,57 +307,79 @@ void MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::TypeMap& first_use_type_map) {
-  // Associate input specs with inputs
-  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
-  for (auto& in : g->inputs()) {
-    if (static_params.find(in) == static_params.end()) {
-      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-      auto est_type_opt = first_use_type_map.find(in)->second;
-      if (est_type_opt && !spec.dtype_is_user_defined) {
+    ir::CollectionTypeMap& first_use_type_map) {
+  cfg.convert_info.collection_input_spec_map =
+      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+
+  auto collection_inputs = ir::get_collection_inputs(g, static_params);
+  LOG_DEBUG(
+      "In MapInputsAndDetermineDTypes, the g->inputs() size is "
+      << g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size());
+
+  for (auto in : collection_inputs) {
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+    std::vector<c10::optional<at::ScalarType>> est_type_opt;
+
+    auto est_it = first_use_type_map.find(in);
+    if (est_it != first_use_type_map.end()) {
+      est_type_opt = first_use_type_map.find(in)->second;
+    }
+    // traverse the elements in est_type_opt and spec together
+    for (size_t i = 0; i < est_type_opt.size(); i++) {
+      if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
         // type
         LOG_INFO(
-            "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
-            << in->debugName() << " has type " << est_type_opt.value()
-            << ". If this is incorrect explicitly set dtype for input and file a bug");
-        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+            "Since input type is not explicitly defined, inferring using first tensor calculation\n Inferred input "
+            << in->debugName() << " has type " << est_type_opt[i].value());
+        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+      } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calculations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
-        spec.dtype = nvinfer1::DataType::kFLOAT;
-      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-        if (!est_type_opt) {
-          LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt[i]) {
+          LOG_INFO("Cannot infer input tensor dtype in graph; the compiler is going to use the user setting");
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite the type map with the user settings
+          first_use_type_map[in][i] = {
+              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+
         } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+              est_type_opt[i].value()) {
             std::stringstream ss;
             ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt.value() << std::endl;
-            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << est_type_opt[i].value() << std::endl;
+            ss << "The compiler is going to use the user setting "
+               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
             ss << "compatibility with PyTorch's data type convention is required.\n";
             ss << "If you do indeed see errors at runtime either:\n";
             ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
             ss << "- Disable partial compilation by setting require_full_compilation to True";
             auto warn_str = ss.str();
             LOG_WARNING(warn_str);
+            // Overwrite the type map with the user settings
+            first_use_type_map[in][i] = {
+                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
           }
-          // Overwrite type map with user settings
-          // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
         }
       } else {
        // The user defined the type so no changes are necessary
       }
     }
   }
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
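The signature change above is the core of this patch: `ir::TypeMap` held a single inferred dtype per graph input, while `ir::CollectionTypeMap` holds one optional dtype per element of a (possibly collection-typed) input, which is why the body gains the inner `for (size_t i = ...)` loop. The shapes involved are roughly as follows, reconstructed from the usage in this hunk rather than copied from the ir headers:

```cpp
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/ir/ir.h>

// Reconstructed from usage in MapInputsAndDetermineDTypes; the actual
// aliases live in Torch-TensorRT's ir headers and may differ in detail.
using TypeMap = std::unordered_map<
    const torch::jit::Value*,
    c10::optional<at::ScalarType>>; // old: one dtype guess per input

using CollectionTypeMap = std::unordered_map<
    const torch::jit::Value*,
    std::vector<c10::optional<at::ScalarType>>>; // new: one guess per element
```

Note that the `est_it != first_use_type_map.end()` guard is also new: the old code dereferenced `find(in)->second` unconditionally, which is undefined behavior when an input has no recorded first use.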
@@ -370,7 +394,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
@@ -395,23 +419,26 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partition_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
   if (cfg.partition_info.enabled &&
-      !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
-    auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+       outputIsCollection)) {
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto graph_and_mapping =
-        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto collection_input_ivalues_map =
+        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
+    auto graph_and_mapping = ConstructFallbackGraph(
+        new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     // rename the graph inputs after fallback to ensure PyTorch deserializes them correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
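The fallback trigger widens here: partitioning now also runs when the top-level output is a collection (`conversion::OutputIsCollection`), because a graph compiled entirely to a TensorRT engine cannot emit tuples or lists itself, and the random inputs used to propagate shapes through segments are now generated from the collection spec map. The new condition, distilled into a self-contained predicate (parameter names are descriptive, not from the codebase):

```cpp
// Distilled form of the condition above: partition when anything blocks a
// full TensorRT conversion, or when the graph returns a tuple/list that the
// engine cannot produce on its own. Parameter names are illustrative only.
bool ShouldRunPartitioning(
    bool partitioning_enabled,
    bool has_forced_fallback, // forced fallback modules or operators present
    bool block_convertible, // VerifyConverterSupportForBlock(g->block(), true)
    bool output_is_collection) { // conversion::OutputIsCollection(g->block())
  const bool fully_supported = !has_forced_fallback && block_convertible;
  return partitioning_enabled && (!fully_supported || output_is_collection);
}
```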
@@ -429,6 +456,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     TORCHTRT_CHECK(
         conversion::VerifyConverterSupportForBlock(g->block()),
         "Not all operations in graph are supported by the compiler");
+    // TODO find the right
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   }