import operator
import typing
from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        return has_overlap

+    @classmethod
+    def _debug_message_from_specs(
+        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
+    ) -> str:
+        message = (
+            f"lhs lifetime: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} "
+        )
+        message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} "
+        message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory} "
+        return message
+
    def verify_storage_reuse(
        self, allow_lifetime_and_storage_overlap: bool = False
    ) -> int:
@@ -159,7 +170,7 @@ def verify_storage_reuse(
                        lhs_spec, rhs_spec
                    ):
                        raise InternalError(
-                            f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
+                            f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
                        )

        # Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
    return specs


+@dataclass
+class AllocationSpec:
+    """
+    AllocationSpec is used to represent the allocation of a tensor.
+    """
+
+    # The offset of the tensor in the shared object/pool.
+    offset: int
+    # The TensorSpec of the tensor being allocated.
+    spec: TensorSpec
+
+
@dataclass
class SharedObject:
    r"""
@@ -470,8 +493,15 @@ class SharedObject:
    offset: int
    # size of this shared object in bytes
    size: int
+    # the index at which this object is first used
+    first_used_index: int
    # the object will be available for index (last_used_index + 1)
    last_used_index: int
+    # list of allocations belonging to this shared object
+    allocations: List[AllocationSpec] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index}, {self.last_used_index}])"


def materialize_buffer(
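To make the two dataclasses above concrete, here is a minimal, self-contained sketch (not part of the PR); `FakeSpec` is a hypothetical stand-in for ExecuTorch's `TensorSpec`, reduced to the two fields the planner reads:

```python
# A minimal sketch (not from the PR): simplified stand-ins showing how
# AllocationSpec entries pack into a SharedObject.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class FakeSpec:
    lifetime: Tuple[int, int]  # [first node index used, last node index used]
    allocated_memory: int      # size in bytes


@dataclass
class AllocationSpec:
    offset: int     # offset of the tensor inside the shared object/pool
    spec: FakeSpec


@dataclass
class SharedObject:
    idx: int
    offset: int                # -1 until materialize_buffer assigns one
    size: int
    first_used_index: int
    last_used_index: int
    allocations: List[AllocationSpec] = field(default_factory=list)


# Two tensors with disjoint lifetimes can both sit at offset 0 of one object.
sobj = SharedObject(idx=0, offset=-1, size=1024, first_used_index=0, last_used_index=3)
sobj.allocations.append(AllocationSpec(0, FakeSpec((0, 3), 1024)))
sobj.allocations.append(AllocationSpec(0, FakeSpec((4, 7), 512)))
sobj.last_used_index = 7  # lifetime expands to cover the second tensor
```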
@@ -489,35 +519,124 @@ def materialize_buffer(
    return total_size


-def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
+def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
    r"""
-    Calculate the absolute different between the size of a shared object and
-    a tensor.
+    Check if a shared object and a tensor do not overlap.
    """
-    return abs(sobj.size - spec.allocated_memory)
+    for alloc in sobj.allocations:
+        if not (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            return False
+    return True
+
+
+def _find_max_overlapping_allocations_offset(
+    sobj: SharedObject, spec: TensorSpec
+) -> int:
+    max_offset = 0
+    for alloc in sobj.allocations:
+        if (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            continue
+        max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
+    return max_offset


def pick_shared_obj(
-    shared_objects: List[SharedObject], spec: TensorSpec
+    shared_objects: List[SharedObject],
+    spec: TensorSpec,
+    allow_overlapping_allocations: bool = True,
) -> SharedObject:
    r"""
-    Pick the available shared object with closest size to the tensor.
-    If there are no available shared object left, create a new one.
+    Pick the available shared object to which to assign this spec,
+    or create a new one.
+
+    Algorithm details:
+    Previous: look at every spec in chronological order and check whether a
+    previously allocated object allows it to fit. If not, allocate a new object.
+    New:
+    - Sort all the specs by allocation size.
+    - Process the specs in order.
+    - If the spec's size is smaller than a previously allocated bucket:
+        - Conditions under which the previously allocated bucket can be used:
+            - The lifetime of the spec does not overlap with the lifetime of the bucket:
+                - In this case, allocate the spec to that bucket and expand its lifetime.
+                - The spec is allocated at offset = 0 in this bucket.
+                - Add this spec to the bucket's list of allocations.
+            - The lifetime of the spec overlaps with the lifetime of the bucket,
+              partially or fully (e.g. the spec's lifetime is a subset of the bucket's):
+                - If none of the specs in the bucket overlap with the spec's lifetime:
+                    - Allocate the spec to the bucket at offset = 0.
+                    - Add this spec to the bucket's list of allocations.
+                    - Expand the bucket's lifetime to account for the added spec's lifetime.
+                - If one or more specs in the bucket overlap with the spec's lifetime:
+                    - Collect the offsets (at which each overlapping spec is allocated
+                      in the bucket) of all the overlapping specs, and find the max offset.
+                    - Allocate the spec to the bucket at offset = max_offset + max_offset_spec_size.
+                    - Add this spec to the bucket's list of allocations.
+                    - Expand the bucket's lifetime to account for the added spec's lifetime.
+        - If none of these conditions are met, allocate a new bucket:
+            - Add the spec to this bucket.
+            - Update the bucket's lifetime to that of the spec.
+    - If the spec's size is larger than all previously allocated buckets, allocate a new bucket:
+        - The size and lifetime of this bucket are those of the spec.
+
+    Proof of correctness:
+    - If allocating a new bucket, the allocation is correct.
+    - If allocating a spec to an existing bucket whose lifetime does not overlap with any
+      of the previously allocated specs' lifetimes, the allocation is correct.
+    Proof of correctness by induction when adding a spec to an existing bucket:
+    - If all previous allocations in the given bucket are correct:
+        - Then the new one being added must be correct because, when the requested allocation
+          overlaps with one or more previous allocations, we find the largest offset among
+          all the overlapping allocations and allocate the new spec at that offset. Hence
+          the allocation at that offset cannot overlap with any previous allocation.
+    Base case: a newly added allocation within a bucket holding a single allocation is
+    correct because a) it must fit and b) its lifetime must not overlap with the object's
+    lifetime. This holds true because of the following invariants:
+    - Once a bucket is created, it is never resized.
+    - All the allocations within a bucket satisfy this: the spans (an allocation's
+      offset + size) of two allocations may overlap only if their lifetimes do not overlap.
    """
-    # TODO: do better than linear scan
    picked = None
    for sobj in shared_objects:
-        if spec.lifetime[0] > sobj.last_used_index:
-            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
-                picked, spec
-            ):
-                picked = sobj
-                sobj.last_used_index = spec.lifetime[1]
-                sobj.size = max(sobj.size, spec.allocated_memory)
+        if _does_not_overlap(sobj, spec):
+            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
+            picked = sobj
+            sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+            sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+            allocation_spec = AllocationSpec(0, spec)
+            picked.allocations.append(allocation_spec)
+            break
+
+    if picked is None and allow_overlapping_allocations:
+        for sobj in shared_objects:
+            max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
+            if max_offset > 0:
+                if max_offset + spec.allocated_memory <= sobj.size:
+                    picked = sobj
+                    sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+                    sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+                    allocation_spec = AllocationSpec(max_offset, spec)
+                    picked.allocations.append(allocation_spec)
+                    break
+
    if picked is None:
        picked = SharedObject(
-            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
+            len(shared_objects),
+            -1,
+            spec.allocated_memory,
+            spec.lifetime[0],
+            spec.lifetime[1],
        )
+        allocation_spec = AllocationSpec(0, spec)
+        picked.allocations.append(allocation_spec)
+        picked.first_used_index = spec.lifetime[0]
+        picked.last_used_index = spec.lifetime[1]
        shared_objects.append(picked)

    return picked
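The two helper functions drive the core decision above: a spec with a lifetime disjoint from every allocation reuses offset 0, while an overlapping spec is stacked directly above the highest overlapping allocation. A self-contained sketch (hypothetical `Alloc` type, not the PR's classes) tracing that behavior:

```python
# Sketch of the lifetime/offset logic in _does_not_overlap and
# _find_max_overlapping_allocations_offset, using a hypothetical Alloc type.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Alloc:
    offset: int
    lifetime: Tuple[int, int]
    size: int


def does_not_overlap(allocations: List[Alloc], lifetime: Tuple[int, int]) -> bool:
    # True only if `lifetime` is disjoint from every existing allocation.
    return all(
        lifetime[1] < a.lifetime[0] or lifetime[0] > a.lifetime[1] for a in allocations
    )


def max_overlapping_offset(allocations: List[Alloc], lifetime: Tuple[int, int]) -> int:
    # Highest offset + size among allocations whose lifetimes intersect `lifetime`.
    max_offset = 0
    for a in allocations:
        if lifetime[1] < a.lifetime[0] or lifetime[0] > a.lifetime[1]:
            continue
        max_offset = max(a.offset + a.size, max_offset)
    return max_offset


allocs = [Alloc(offset=0, lifetime=(0, 5), size=512)]
# A tensor live during nodes 2..4 overlaps the first allocation...
assert not does_not_overlap(allocs, (2, 4))
# ...so it would be stacked directly above it, at offset 512.
assert max_overlapping_offset(allocs, (2, 4)) == 512
# A tensor live during nodes 6..9 does not overlap and can reuse offset 0.
assert does_not_overlap(allocs, (6, 9))
```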
@@ -556,7 +675,16 @@ def greedy(
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
) -> List[int]:
+    r"""Greedy algorithm to allocate memory for tensors in the graph.
+
+    alloc_graph_input: If set to True, the algorithm will allocate memory for the graph input.
+    alloc_graph_output: If set to True, the algorithm will allocate memory for the graph output.
+    allow_overlapping_allocations: If set to True (the default), allows allocations that
+        overlap in their lifetime but are at different offsets in the storage. This flag
+        is added to allow Vulkan to use MemoryPlanningPass with overlapping allocations
+        disabled.
+    """
    spec2obj = {}
    shared_objects = defaultdict(list)
    # Don't do assertion in collect_specs_from_nodes if we have already encountered
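A hypothetical usage sketch of the new flag: a backend that cannot tolerate overlapping allocations (e.g. Vulkan, per the docstring) could bind it ahead of time and hand the resulting planner to `MemoryPlanningPass`. The exact wiring is an assumption, not part of this diff:

```python
# Hypothetical: pre-bind the flag so the planner can be passed wherever a
# memory-planning algorithm is expected. Not shown in this PR.
from functools import partial

non_overlapping_greedy = partial(greedy, allow_overlapping_allocations=False)
# e.g. MemoryPlanningPass(memory_planning_algo=non_overlapping_greedy)
```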
@@ -565,24 +693,34 @@ def greedy(
    # For each tensor, pick the available shared object with closest size to
    # the tensor. If there are no available shared object left, create a new
    # one.
+    import bisect
+
+    sorted_specs = []
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        do_assertion=do_assertion,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
+        bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+    sorted_specs.reverse()
+
+    for spec in sorted_specs:
        if spec.mem_id is None:
            spec.mem_id = 1
        spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+        )

    if len(shared_objects) == 0:
        # Cannot find any tensor in the graph that needs to be allocated.
        # Return [0, 0] to be consistent with default behavior of naive.
        total_sizes = [0, 0]
    else:
        total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_specs_processed = 0
        for mem_id in shared_objects:
            input_total_size = 0
            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
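The `bisect.insort(..., key=...)` call above builds an ascending list that is then reversed so the largest specs are processed first, which is what makes the "Allocation specs are not sorted" assertion in `pick_shared_obj` hold. Note that the `key=` parameter of `bisect.insort` requires Python 3.10+. A standalone sketch with plain integers standing in for `allocated_memory`:

```python
# Standalone sketch: descending-size ordering via bisect.insort + reverse.
# Plain ints stand in for specs; the diff keys on x.allocated_memory.
import bisect

sizes = [128, 1024, 256, 1024, 64]
sorted_sizes = []
for s in sizes:
    bisect.insort(sorted_sizes, s)  # keep the list sorted ascending
sorted_sizes.reverse()              # largest first

assert sorted_sizes == sorted(sizes, reverse=True)  # [1024, 1024, 256, 128, 64]
```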
@@ -594,13 +732,25 @@ def greedy(
            total_sizes[mem_id] = materialize_buffer(
                shared_objects[mem_id], input_total_size
            )
-
-    # Since we now know the number of shared objects we need and the size of
-    # each shared object, we can assign offset in the memory buffer for each
-    # shared object.
-    for spec, sobj in spec2obj.items():
-        spec.mem_obj_id = sobj.idx
-        spec.mem_offset = sobj.offset
+            # Pad the allocation with 64 bytes.
+            # This requirement is really for the XNNPACK backend, which can read
+            # beyond the end of a tensor as a performance optimization.
+            # While accounting for a backend-specific requirement is not the right
+            # choice in backend-agnostic memory planning, we do it here for now.
+            total_sizes[mem_id] += 64
+            # Since we now know the number of shared objects we need and the size of
+            # each shared object, we can assign an offset in the memory buffer for
+            # each shared object.
+            for sobj in shared_objects[mem_id]:
+                for alloc in sobj.allocations:
+                    spec = alloc.spec
+                    alloc.spec.mem_obj_id = sobj.idx
+                    alloc.spec.mem_offset = sobj.offset + alloc.offset
+                    num_specs_processed += 1
+        assert (
+            len(spec2obj) == num_specs_processed
+        ), f"All specs should be processed, but only {num_specs_processed} of {len(spec2obj)} were"

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
    return total_sizes
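A worked example (hypothetical numbers) of the offset arithmetic in the loop above: a tensor's absolute `mem_offset` is its object's offset in the arena, as assigned by `materialize_buffer`, plus the allocation's offset within that object.

```python
# Hypothetical numbers illustrating mem_offset = sobj.offset + alloc.offset.
object_offset = 128  # where materialize_buffer placed this SharedObject
allocations = [(0, 512), (512, 256)]  # (offset within object, size) pairs

for offset_in_object, size in allocations:
    mem_offset = object_offset + offset_in_object
    print(f"tensor of {size} bytes -> mem_offset {mem_offset}")
# tensor of 512 bytes -> mem_offset 128
# tensor of 256 bytes -> mem_offset 640
```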