@@ -509,26 +509,25 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
509
509
underlying data will be owned by multiple delegates.
510
510
"""
511
511
for node in edge_program .graph .nodes :
512
- # go through const/param/buffer nodes
513
- is_attr = (
514
- node .op == "placeholder"
515
- and (
516
- is_param (edge_program , node )
517
- or is_buffer (edge_program , node )
518
- or is_lifted_tensor_constant (edge_program , node )
519
- )
520
- ) or (node .op == "get_attr" )
521
- # if all users of const/param/buffer nodes are partitioned then partition
522
- if is_attr :
512
+ # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
513
+ if node .op == "placeholder" and (
514
+ is_param (edge_program , node )
515
+ or is_buffer (edge_program , node )
516
+ or is_lifted_tensor_constant (edge_program , node )
517
+ ):
523
518
user_tags = set ()
524
519
for user in node .users :
525
- user_tags .add (user .meta .get ("delegation_tag" , None ))
526
- assert len (user_tags ) <= 1 , (
527
- "Const/Param/Buffer users have multiple tags because one constant data can't "
528
- "be owned by multiple backends. Consider duplicating the constant data so that "
529
- "each user is unique"
530
- )
531
- if len (user_tags ) == 1 :
520
+ user_tag = user .meta .get ("delegation_tag" , None )
521
+ if user_tag is not None :
522
+ user_tags .add (user_tag )
523
+ if len (user_tags ) > 1 :
524
+ logging .info (
525
+ f"The data node is used across multiple partitions, including { user_tags } . "
526
+ "If the data is too large and it's not preferred to copy, please tag the "
527
+ "constant node like node.['no_copy'] = True and they won't be copied."
528
+ )
529
+ # tag the data node with the same tag as the last user
530
+ if len (user_tags ) > 0 :
532
531
node .meta ["delegation_tag" ] = user_tags .pop ()
533
532
534
533
0 commit comments