@@ -886,17 +886,31 @@ class RewriteScalarExtractOfTransferRead
886
886
SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
887
887
xferOp.getIndices ().end ());
888
888
for (auto [i, pos] : llvm::enumerate (extractOp.getMixedPosition ())) {
889
- assert (isa<Attribute>(pos) && " Unexpected non-constant index" );
890
- int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt ();
891
889
int64_t idx = newIndices.size () - extractOp.getNumIndices () + i;
892
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
893
- rewriter, extractOp.getLoc (),
894
- rewriter.getAffineSymbolExpr (0 ) + offset, {newIndices[idx]});
895
- if (auto value = dyn_cast<Value>(ofr)) {
890
+
891
+ // Compute affine expression `newIndices[idx] + pos` where `pos` can be
892
+ // either a constant or a value.
893
+ OpFoldResult composedIdx;
894
+ if (auto attr = dyn_cast<Attribute>(pos)) {
895
+ int64_t offset = cast<IntegerAttr>(attr).getInt ();
896
+ composedIdx = affine::makeComposedFoldedAffineApply (
897
+ rewriter, extractOp.getLoc (),
898
+ rewriter.getAffineSymbolExpr (0 ) + offset, {newIndices[idx]});
899
+ } else {
900
+ Value dynamicOffset = cast<Value>(pos);
901
+ AffineExpr sym0, sym1;
902
+ bindSymbols (rewriter.getContext (), sym0, sym1);
903
+ composedIdx = affine::makeComposedFoldedAffineApply (
904
+ rewriter, extractOp.getLoc (), sym0 + sym1,
905
+ {newIndices[idx], dynamicOffset});
906
+ }
907
+
908
+ // Update the corresponding index with the folded result.
909
+ if (auto value = dyn_cast<Value>(composedIdx)) {
896
910
newIndices[idx] = value;
897
911
} else {
898
912
newIndices[idx] = rewriter.create <arith::ConstantIndexOp>(
899
- extractOp.getLoc (), *getConstantIntValue (ofr ));
913
+ extractOp.getLoc (), *getConstantIntValue (composedIdx ));
900
914
}
901
915
}
902
916
if (isa<MemRefType>(xferOp.getBase ().getType ())) {
0 commit comments