@@ -232,7 +232,9 @@ def partition(
232
232
233
233
# Check Owning Program still owns all constant data
234
234
owning_program = delegated .exported_program ()
235
- self .assertEqual (len (owning_program .state_dict ), 3 )
235
+ self .assertEqual (
236
+ len (owning_program .state_dict ) + len (owning_program .constants ), 3
237
+ )
236
238
self .assertEqual (len (owning_program .graph_signature .buffers ), 2 )
237
239
self .assertEqual (len (owning_program .graph_signature .parameters ), 1 )
238
240
@@ -321,7 +323,7 @@ def partition(
321
323
delegated .exported_program ().graph_module , lowered_module_node .name
322
324
)
323
325
delegated_ep = lower_module .original_module
324
- self .assertEqual (len (delegated_ep .state_dict ), 3 )
326
+ self .assertEqual (len (delegated_ep .state_dict ) + len ( delegated_ep . constants ) , 3 )
325
327
self .assertEqual (len (delegated_ep .graph_signature .buffers ), 2 )
326
328
self .assertEqual (len (delegated_ep .graph_signature .parameters ), 1 )
327
329
@@ -375,7 +377,9 @@ def partition(
375
377
376
378
# Check Owning Program still owns only buffers
377
379
owning_program = delegated .exported_program ()
378
- self .assertEqual (len (owning_program .state_dict ), 2 )
380
+ self .assertEqual (
381
+ len (owning_program .state_dict ) + len (owning_program .constants ), 2
382
+ )
379
383
self .assertEqual (len (owning_program .graph_signature .buffers ), 2 )
380
384
self .assertEqual (len (owning_program .graph_signature .parameters ), 0 )
381
385
0 commit comments