Skip to content

Commit 662400e

Browse files
c-rhodesAlexisPerry
authored andcommitted
[mlir][vector] Fix FlattenGather for scalable vectors (llvm#96074)
This pattern flattens vector.gather ops by unrolling the outermost dimension for rank > 2 vectors. There's two issues with this pattern for scalable vectors: 1. The unrolling doesn't take vscale into account. A constraint is added to disable this pattern for vectors with leading scalable dims. 2. The scalable dims are dropped when creating the new gather. Fixed by propagating the flags. Depends on llvm#96049.
1 parent e6f6724 commit 662400e

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ namespace {
5555
/// ```
5656
///
5757
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
58+
///
59+
/// Supports vector types with a fixed leading dimension.
5860
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
5961
using OpRewritePattern::OpRewritePattern;
6062

@@ -64,6 +66,11 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
6466
if (resultTy.getRank() < 2)
6567
return rewriter.notifyMatchFailure(op, "already flat");
6668

69+
// Unrolling doesn't take vscale into account. Pattern is disabled for
70+
// vectors with leading scalable dim(s).
71+
if (resultTy.getScalableDims().front())
72+
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
73+
6774
Location loc = op.getLoc();
6875
Value indexVec = op.getIndexVec();
6976
Value maskVec = op.getMask();
@@ -72,8 +79,7 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
7279
Value result = rewriter.create<arith::ConstantOp>(
7380
loc, resultTy, rewriter.getZeroAttr(resultTy));
7481

75-
Type subTy = VectorType::get(resultTy.getShape().drop_front(),
76-
resultTy.getElementType());
82+
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
7783

7884
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
7985
int64_t thisIdx[1] = {i};

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,43 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
7474
return %0 : vector<2x3xf32>
7575
}
7676

77+
// CHECK-LABEL: @scalable_gather_memref_2d
78+
// CHECK-SAME: %[[BASE:.*]]: memref<?x?xf32>,
79+
// CHECK-SAME: %[[IDXVEC:.*]]: vector<2x[3]xindex>,
80+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[3]xi1>,
81+
// CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
82+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
83+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
84+
// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
85+
// CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
86+
// CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
87+
// CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
88+
// CHECK: %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
89+
// CHECK: %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
90+
// CHECK: %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
91+
// CHECK: %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
92+
// CHECK: %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
93+
// CHECK: %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
94+
// CHECK: %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
95+
// CHECK-NEXT: return %[[INS1]] : vector<2x[3]xf32>
96+
func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
97+
%c0 = arith.constant 0 : index
98+
%c1 = arith.constant 1 : index
99+
%0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
100+
return %0 : vector<2x[3]xf32>
101+
}
102+
103+
// CHECK-LABEL: @scalable_gather_cant_unroll
104+
// CHECK-NOT: extract
105+
// CHECK: vector.gather
106+
// CHECK-NOT: extract
107+
func.func @scalable_gather_cant_unroll(%base: memref<?x?xf32>, %v: vector<[4]x8xindex>, %mask: vector<[4]x8xi1>, %pass_thru: vector<[4]x8xf32>) -> vector<[4]x8xf32> {
108+
%c0 = arith.constant 0 : index
109+
%c1 = arith.constant 1 : index
110+
%0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<[4]x8xindex>, vector<[4]x8xi1>, vector<[4]x8xf32> into vector<[4]x8xf32>
111+
return %0 : vector<[4]x8xf32>
112+
}
113+
77114
// CHECK-LABEL: @gather_tensor_1d
78115
// CHECK-SAME: ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
79116
// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : i1 from vector<2xi1>

0 commit comments

Comments
 (0)