Skip to content

Commit ed7d5bd

Browse files
committed
[mlir][Vector] Support xfer_read(vector.extract)) folding with dynamic indices
This PR is part of the step to remove `vector.extractelement` and `vector.insertelement` ops. It adds support for folding `vector.transfer_read(vector.extract) -> memref.load` with dynamic indices, which is currently supported by `vector.extractelement`.
1 parent eed98e1 commit ed7d5bd

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -886,12 +886,26 @@ class RewriteScalarExtractOfTransferRead
886886
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
887887
xferOp.getIndices().end());
888888
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
889-
assert(isa<Attribute>(pos) && "Unexpected non-constant index");
890-
int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
891889
int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
892-
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
893-
rewriter, extractOp.getLoc(),
894-
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
890+
891+
// Compute affine expression `newIndices[idx] + pos` where `pos` can be
892+
// either a constant or a value.
893+
OpFoldResult ofr;
894+
if (auto attr = dyn_cast<Attribute>(pos)) {
895+
int64_t offset = cast<IntegerAttr>(attr).getInt();
896+
ofr = affine::makeComposedFoldedAffineApply(
897+
rewriter, extractOp.getLoc(),
898+
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
899+
} else {
900+
Value dynamicOffset = cast<Value>(pos);
901+
AffineExpr sym0, sym1;
902+
bindSymbols(rewriter.getContext(), sym0, sym1);
903+
ofr = affine::makeComposedFoldedAffineApply(
904+
rewriter, extractOp.getLoc(), sym0 + sym1,
905+
{newIndices[idx], dynamicOffset});
906+
}
907+
908+
// Update the corresponding index with the folded result.
895909
if (auto value = dyn_cast<Value>(ofr)) {
896910
newIndices[idx] = value;
897911
} else {

mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
148148
return %1 : vector<16xf32>
149149
}
150150

151+
// -----
152+
153+
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
154+
// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
155+
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
156+
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
157+
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
158+
func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
159+
%offset: index) -> f32 {
160+
%cst = arith.constant 0.0 : f32
161+
%vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
162+
%elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
163+
return %elem : f32
164+
}
165+
166+
// -----
167+
168+
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
169+
// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
170+
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
171+
// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
172+
// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
173+
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
174+
func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
175+
%row_offset: index, %col_offset: index) -> f32 {
176+
%cst = arith.constant 0.0 : f32
177+
%vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
178+
%elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
179+
return %elem : f32
180+
}

0 commit comments

Comments
 (0)