11
11
import operator
12
12
import typing
13
13
from collections import defaultdict
14
- from dataclasses import dataclass
14
+ from dataclasses import dataclass , field
15
15
from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Tuple , Union
16
16
17
17
import torch
@@ -454,6 +454,18 @@ def update_all_tensors_lifetime(
454
454
return specs
455
455
456
456
457
+ @dataclass
458
+ class AllocationSpec :
459
+ """
460
+ AllocationSpec is used to represent the allocation of a tensor.
461
+ """
462
+
463
+ # The offset of the tensor in the shared object/pool.
464
+ offset : int
465
+ # TensorSpec
466
+ spec : TensorSpec
467
+
468
+
457
469
@dataclass
458
470
class SharedObject :
459
471
r"""
@@ -470,8 +482,15 @@ class SharedObject:
470
482
offset : int
471
483
# size of this shared object in bytes
472
484
size : int
485
+ # When the object is first created
486
+ first_used_index : int
473
487
# the object will be available for index (last_used_index + 1)
474
488
last_used_index : int
489
+ # list of allocations belong to this shared object
490
+ allocations : List [AllocationSpec ] = field (default_factory = list )
491
+
492
+ def __repr__ (self ) -> str :
493
+ return f"SharedObject(idx={ self .idx } , offset={ self .offset } , size={ self .size } , lifetime=[{ self .first_used_index , self .last_used_index } ])"
475
494
476
495
477
496
def materialize_buffer (
@@ -489,35 +508,122 @@ def materialize_buffer(
489
508
return total_size
490
509
491
510
492
- def _size_abs_dif (sobj : SharedObject , spec : TensorSpec ) -> int :
511
+ def _does_not_overlap (sobj : SharedObject , spec : TensorSpec ) -> bool :
493
512
r"""
494
- Calculate the absolute different between the size of a shared object and
495
- a tensor.
513
+ Check if a shared object and a tensor do not overlap.
496
514
"""
497
- return abs (sobj .size - spec .allocated_memory )
515
+ for alloc in sobj .allocations :
516
+ if not (
517
+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
518
+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
519
+ ):
520
+ return False
521
+ return True
522
+
523
+
524
+ def _find_max_overlapping_allocations_offset (
525
+ sobj : SharedObject , spec : TensorSpec
526
+ ) -> int :
527
+ max_offset = 0
528
+ for alloc in sobj .allocations :
529
+ if (
530
+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
531
+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
532
+ ):
533
+ continue
534
+ max_offset = max (alloc .offset + alloc .spec .allocated_memory , max_offset )
535
+ return max_offset
498
536
499
537
500
538
def pick_shared_obj (
501
539
shared_objects : List [SharedObject ], spec : TensorSpec
502
540
) -> SharedObject :
503
541
r"""
504
- Pick the available shared object with closest size to the tensor.
505
- If there are no available shared object left, create a new one.
542
+ Pick the available shared object to which to assign this spec,
543
+ or create a new one
544
+ Algorithm details
545
+ Previous: Look at every spec in chronological order. Find if previously allocated object
546
+ allows it to fit in. If not, allocate a new object.
547
+ New:
548
+ - Sort all the specs by allocation size
549
+ - Process the specs in order
550
+ - If the spec's size in smaller than previously allocated buckets:
551
+ - Conditions under which previously allocated bucket can be used:
552
+ - Lifetime of the spec does not overlap with lifetime of the bucket.
553
+ - In this case allocate spec to that bucket and expand its lifetime.
554
+ - Spec is allocated at offset = 0 in this bucket.
555
+ - Add this spec to allocated object's list of specs.
556
+ - Lifetime of the spec overlaps with lifetime of the bucket,
557
+ partially or fully (e.g. spec's lifetime subset of bucket's lifetime)
558
+ - If none of the specs in the bucket overlaps with spec's lifetime.
559
+ - Allocate spec to the bucket at offset = 0.
560
+ - Add this spec to the bucket's list of specs.
561
+ - Expand bucket's lifetime accounting for added spec's lifetime.
562
+ - If one or more specs in the bucket overlaps with spec's lifetime.
563
+ - Collect offsets (at which the given overlapping spec is allocated in the bucket).
564
+ of all the overlapping specs, and find the max offset.
565
+ - Allocate spec to the bucket at offset = max_offset + max_offset_spec_size.
566
+ - Add this spec to the bucket's list of specs.
567
+ - Expand bucket's lifetime accounting for added spec's lifetime.
568
+ - If none of these conditions are met, allocate a new bucket.
569
+ - Add spec to this bucket.
570
+ - Update bucket's lifetime to that of the spec.
571
+ - If the spec's size is larger than previously allocated buckets, allocate a new bucket.
572
+ - Size and lifetime of this bucket is that of the spec
573
+
574
+ Proof of correctness:
575
+ - If allocating a new bucket, it is correct.
576
+ - If allocating spec to an existing bucket, whose lifetime does not overlap with any
577
+ of the previously allocated specs' lifetime, then the allocation is correct.
578
+ Proof of correctness by induction when adding spec to an existing bucket:
579
+ - If all previous allocations in the given bucket are correct:
580
+ - Then the new one being added must be correct because when the requested allocation
581
+ overlaps with one or more previous allocations, we find the largest offset among
582
+ all the overlapping allocations, and allocate the new spec at that offset. Hence,
583
+ the allocation at such an offset, will not overlap with any previous allocations.
584
+ Base case: A newly added allocation within a bucket with single allocation is correct:
585
+ because a) it must fit and b) its lifetime must not overlap with object's lifetime.
586
+ This holds true because of the following invariants:
587
+ - Once a bucket is created, it is never resized.
588
+ - All the allocations within a bucket follow this:
589
+ - Span, defined by allocation's offset + size, of two allocations can only overlap,
590
+ if their timelines do not overlap.
506
591
"""
507
- # TODO: do better than linear scan
508
592
picked = None
509
593
for sobj in shared_objects :
510
- if spec .lifetime [0 ] > sobj .last_used_index :
511
- if picked is None or _size_abs_dif (sobj , spec ) < _size_abs_dif (
512
- picked , spec
513
- ):
514
- picked = sobj
515
- sobj .last_used_index = spec .lifetime [1 ]
516
- sobj .size = max (sobj .size , spec .allocated_memory )
594
+ if _does_not_overlap (sobj , spec ):
595
+ assert sobj .size >= spec .allocated_memory , "Allocation specs are not sorted"
596
+ picked = sobj
597
+ sobj .first_used_index = min (sobj .first_used_index , spec .lifetime [0 ])
598
+ sobj .last_used_index = max (sobj .last_used_index , spec .lifetime [1 ])
599
+ allocation_spec = AllocationSpec (0 , spec )
600
+ picked .allocations .append (allocation_spec )
601
+ break
602
+
603
+ if picked is None :
604
+ for sobj in shared_objects :
605
+ max_offset = _find_max_overlapping_allocations_offset (sobj , spec )
606
+ if max_offset > 0 :
607
+ if max_offset + spec .allocated_memory <= sobj .size :
608
+ picked = sobj
609
+ sobj .first_used_index = min (sobj .first_used_index , spec .lifetime [0 ])
610
+ sobj .last_used_index = max (sobj .last_used_index , spec .lifetime [1 ])
611
+ allocation_spec = AllocationSpec (max_offset , spec )
612
+ picked .allocations .append (allocation_spec )
613
+ break
614
+
517
615
if picked is None :
518
616
picked = SharedObject (
519
- len (shared_objects ), - 1 , spec .allocated_memory , spec .lifetime [1 ]
617
+ len (shared_objects ),
618
+ - 1 ,
619
+ spec .allocated_memory ,
620
+ spec .lifetime [0 ],
621
+ spec .lifetime [1 ],
520
622
)
623
+ allocation_spec = AllocationSpec (0 , spec )
624
+ picked .allocations .append (allocation_spec )
625
+ picked .first_used_index = spec .lifetime [0 ]
626
+ picked .last_used_index = spec .lifetime [1 ]
521
627
shared_objects .append (picked )
522
628
523
629
return picked
@@ -565,13 +671,20 @@ def greedy(
565
671
# For each tensor, pick the available shared object with closest size to
566
672
# the tensor. If there are no available shared object left, create a new
567
673
# one.
674
+ import bisect
675
+
676
+ sorted_specs = []
568
677
for spec in collect_specs_from_nodes (
569
678
graph_module .graph .nodes ,
570
679
graph_signature ,
571
680
do_assertion = do_assertion ,
572
681
ignore_graph_input = not alloc_graph_input ,
573
682
ignore_graph_output = not alloc_graph_output ,
574
683
):
684
+ bisect .insort (sorted_specs , spec , key = lambda x : x .allocated_memory )
685
+ sorted_specs .reverse ()
686
+
687
+ for spec in sorted_specs :
575
688
if spec .mem_id is None :
576
689
spec .mem_id = 1
577
690
spec .realign (alignment )
@@ -583,6 +696,7 @@ def greedy(
583
696
total_sizes = [0 , 0 ]
584
697
else :
585
698
total_sizes = [0 ] * (max (shared_objects .keys ()) + 1 )
699
+ num_specs_processed = 0
586
700
for mem_id in shared_objects :
587
701
input_total_size = 0
588
702
if bufsizes := getattr (graph_module , "input_mem_buffer_sizes" , None ):
@@ -594,13 +708,18 @@ def greedy(
594
708
total_sizes [mem_id ] = materialize_buffer (
595
709
shared_objects [mem_id ], input_total_size
596
710
)
597
-
598
- # Since we now know the number of shared objects we need and the size of
599
- # each shared object, we can assign offset in the memory buffer for each
600
- # shared object.
601
- for spec , sobj in spec2obj .items ():
602
- spec .mem_obj_id = sobj .idx
603
- spec .mem_offset = sobj .offset
711
+ # Since we now know the number of shared objects we need and the size of
712
+ # each shared object, we can assign offset in the memory buffer for each
713
+ # shared object.
714
+ for sobj in shared_objects [mem_id ]:
715
+ for alloc in sobj .allocations :
716
+ spec = alloc .spec
717
+ alloc .spec .mem_obj_id = sobj .idx
718
+ alloc .spec .mem_offset = sobj .offset + alloc .offset
719
+ num_specs_processed += 1
720
+ assert (
721
+ len (spec2obj ) == num_specs_processed
722
+ ), f"All specs should be processed but there were { len (spec2obj )} specs and processed { num_specs_processed } specs"
604
723
605
724
logging .debug (f"greedy algorithm returns bufsizes: { total_sizes } " )
606
725
return total_sizes
0 commit comments