@@ -244,7 +244,7 @@ def compile_module(
244
244
dryrun_tracker .total_ops_in_graph = total_ops
245
245
dryrun_tracker .supported_ops_in_graph = num_supported_ops
246
246
dryrun_tracker .graph_input_shapes = parse_complex_tensor_structs (
247
- sample_inputs , "shape" , tuple
247
+ sample_inputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
248
248
)
249
249
dryrun_tracker .graph_input_dtypes = parse_complex_tensor_structs (
250
250
sample_inputs , "torch_dtype"
@@ -356,7 +356,9 @@ def compile_module(
356
356
)
357
357
358
358
subgraph_data .subgraph_input_shapes = parse_complex_tensor_structs (
359
- submodule_inputs , "shape" , tuple
359
+ submodule_inputs ,
360
+ "shape" ,
361
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
360
362
)
361
363
subgraph_data .subgraph_input_dtypes = parse_complex_tensor_structs (
362
364
submodule_inputs , "torch_dtype"
@@ -367,7 +369,9 @@ def compile_module(
367
369
)
368
370
369
371
subgraph_data .subgraph_output_shapes = parse_complex_tensor_structs (
370
- submodule_outputs , "shape" , tuple
372
+ submodule_outputs ,
373
+ "shape" ,
374
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
371
375
)
372
376
subgraph_data .subgraph_output_dtypes = parse_complex_tensor_structs (
373
377
submodule_outputs , "dtype"
@@ -395,7 +399,7 @@ def compile_module(
395
399
sample_outputs = [sample_outputs ]
396
400
397
401
dryrun_tracker .graph_output_shapes = parse_complex_tensor_structs (
398
- sample_outputs , "shape" , tuple
402
+ sample_outputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
399
403
)
400
404
dryrun_tracker .graph_output_dtypes = parse_complex_tensor_structs (
401
405
sample_outputs , "dtype"
0 commit comments