Skip to content

Commit 2edc49b

Browse files
c-rhodesAlexisPerry
authored andcommitted
[mlir][vector] Disable Gather1DToConditionalLoads for scalable vectors (llvm#96049)
Pattern scalarizes vector.gather operations and is incorrect for scalable vectors.
1 parent 0753b56 commit 2edc49b

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
189189
if (resultTy.getRank() != 1)
190190
return rewriter.notifyMatchFailure(op, "unsupported rank");
191191

192+
if (resultTy.isScalable())
193+
return rewriter.notifyMatchFailure(op, "not a fixed-width vector");
194+
192195
Location loc = op.getLoc();
193196
Type elemTy = resultTy.getElementType();
194197
// Vector type with a single element. Used to generate `vector.loads`.

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,13 @@ func.func @strided_gather(%base : memref<100x3xf32>,
206206
// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
207207
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
208208
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
209+
210+
// CHECK-LABEL: @scalable_gather_1d
211+
// CHECK-NOT: extract
212+
// CHECK: vector.gather
213+
// CHECK-NOT: extract
214+
func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask: vector<[2]xi1>, %pass_thru: vector<[2]xf32>) -> vector<[2]xf32> {
215+
%c0 = arith.constant 0 : index
216+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
217+
return %0 : vector<[2]xf32>
218+
}

0 commit comments

Comments
 (0)