@@ -260,7 +260,7 @@ def compile_module(
260
260
dryrun_tracker .total_ops_in_graph = total_ops
261
261
dryrun_tracker .supported_ops_in_graph = num_supported_ops
262
262
dryrun_tracker .graph_input_shapes = parse_complex_tensor_structs (
263
- sample_inputs , "shape" , tuple
263
+ sample_inputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
264
264
)
265
265
dryrun_tracker .graph_input_dtypes = parse_complex_tensor_structs (
266
266
sample_inputs , "torch_dtype"
@@ -372,7 +372,9 @@ def compile_module(
372
372
)
373
373
374
374
subgraph_data .subgraph_input_shapes = parse_complex_tensor_structs (
375
- submodule_inputs , "shape" , tuple
375
+ submodule_inputs ,
376
+ "shape" ,
377
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
376
378
)
377
379
subgraph_data .subgraph_input_dtypes = parse_complex_tensor_structs (
378
380
submodule_inputs , "torch_dtype"
@@ -383,7 +385,9 @@ def compile_module(
383
385
)
384
386
385
387
subgraph_data .subgraph_output_shapes = parse_complex_tensor_structs (
386
- submodule_outputs , "shape" , tuple
388
+ submodule_outputs ,
389
+ "shape" ,
390
+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
387
391
)
388
392
subgraph_data .subgraph_output_dtypes = parse_complex_tensor_structs (
389
393
submodule_outputs , "dtype"
@@ -411,7 +415,7 @@ def compile_module(
411
415
sample_outputs = [sample_outputs ]
412
416
413
417
dryrun_tracker .graph_output_shapes = parse_complex_tensor_structs (
414
- sample_outputs , "shape" , tuple
418
+ sample_outputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
415
419
)
416
420
dryrun_tracker .graph_output_dtypes = parse_complex_tensor_structs (
417
421
sample_outputs , "dtype"
0 commit comments