@@ -480,7 +480,15 @@ class LowerMatrixIntrinsics {
480
480
// / the result value of the instruction, with the only exceptions being store
481
481
// / instructions and the matrix_column_major_store intrinsics. For those, the
482
482
// / shape information indicates that those instructions should be lowered
483
- // / using shape information as well.
483
+ // / using shape information as well. Note that extra care is needed when
484
+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
485
+ // / replacement is also a matrix operation, use
486
+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487
+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
488
+ // / want to add shape information for a replacement instruction. When directly
489
+ // / erasing a value with an entry in ShapeMap, use
490
+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491
+ // / accordingly.
484
492
DenseMap<Value *, ShapeInfo> ShapeMap;
485
493
486
494
// / List of instructions to remove. While lowering, we are not replacing all
@@ -743,6 +751,8 @@ class LowerMatrixIntrinsics {
743
751
return Operation (T0, Shape0.t (), T1, Shape1.t ());
744
752
}
745
753
754
+ // / Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
755
+ // / itself.
746
756
void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
747
757
auto Iter = ShapeMap.find (Inst);
748
758
if (Iter != ShapeMap.end ())
@@ -763,6 +773,8 @@ class LowerMatrixIntrinsics {
763
773
eraseFromParentAndRemoveFromShapeMap (Inst);
764
774
}
765
775
776
+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777
+ // / entry for \p Old and replace all uses of \p Old with \p New.
766
778
void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
767
779
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
768
780
// with New. We should only add New it it supportsShapeInfo so we insert
0 commit comments