Skip to content

Commit 9d78e81

Browse files
committed
[mlir][vector] Fix FlattenGather for scalable vectors
This pattern flattens vector.gather ops by unrolling the outermost dimension for vectors of rank >= 2. There are two issues with this pattern for scalable vectors: 1. The unrolling doesn't take vscale into account. A constraint is added to disable this pattern for vectors with leading scalable dims. 2. The scalable dims are dropped when creating the new gather. This is fixed by propagating the scalable-dim flags to the new sub-vector type. Depends on #96049.
1 parent cc145f4 commit 9d78e81

File tree

5 files changed

+46
-6
lines changed

5 files changed

+46
-6
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
211211
/// static sizes in `shape`.
212212
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
213213
ArrayRef<int64_t> inputVectorSizes);
214+
215+
/// Returns true if the leading dim(s) of `type` are fixed and the trailing dim
216+
/// is scalable.
217+
bool isTrailingDimScalable(VectorType type);
218+
214219
} // namespace vector
215220

216221
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ namespace {
5555
/// ```
5656
///
5757
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
58+
///
59+
/// Supports vector types with trailing scalable dim.
5860
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
5961
using OpRewritePattern::OpRewritePattern;
6062

@@ -64,6 +66,12 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
6466
if (resultTy.getRank() < 2)
6567
return rewriter.notifyMatchFailure(op, "already flat");
6668

69+
// Unrolling doesn't take vscale into account. Pattern is disabled for
70+
// vectors with leading scalable dim(s).
71+
if (resultTy.isScalable() && !isTrailingDimScalable(resultTy))
72+
return rewriter.notifyMatchFailure(
73+
op, "vector type must be fixed-width or scalable in trailing dim");
74+
6775
Location loc = op.getLoc();
6876
Value indexVec = op.getIndexVec();
6977
Value maskVec = op.getMask();
@@ -73,7 +81,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
7381
loc, resultTy, rewriter.getZeroAttr(resultTy));
7482

7583
Type subTy = VectorType::get(resultTy.getShape().drop_front(),
76-
resultTy.getElementType());
84+
resultTy.getElementType(),
85+
resultTy.getScalableDims().drop_front());
7786

7887
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
7988
int64_t thisIdx[1] = {i};

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ class ScalableShapeCastOpRewritePattern
342342
rewriter.replaceOp(op, result);
343343
return success();
344344
}
345-
346-
static bool isTrailingDimScalable(VectorType type) {
347-
return type.getRank() >= 1 && type.getScalableDims().back() &&
348-
!llvm::is_contained(type.getScalableDims().drop_back(), true);
349-
}
350345
};
351346

352347
} // namespace

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,8 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
396396
}
397397
return success();
398398
}
399+
400+
bool vector::isTrailingDimScalable(VectorType type) {
401+
return type.getRank() >= 1 && type.getScalableDims().back() &&
402+
!llvm::is_contained(type.getScalableDims().drop_back(), true);
403+
}

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
7474
return %0 : vector<2x3xf32>
7575
}
7676

77+
// CHECK-LABEL: @scalable_gather_memref_2d
78+
// CHECK-SAME: %[[BASE:.*]]: memref<?x?xf32>,
79+
// CHECK-SAME: %[[IDXVEC:.*]]: vector<2x[3]xindex>,
80+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[3]xi1>,
81+
// CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
82+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
83+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
84+
// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
85+
// CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
86+
// CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
87+
// CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
88+
// CHECK: %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
89+
// CHECK: %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
90+
// CHECK: %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
91+
// CHECK: %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
92+
// CHECK: %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
93+
// CHECK: %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
94+
// CHECK: %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
95+
// CHECK-NEXT: return %[[INS1]] : vector<2x[3]xf32>
96+
func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
97+
%c0 = arith.constant 0 : index
98+
%c1 = arith.constant 1 : index
99+
%0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
100+
return %0 : vector<2x[3]xf32>
101+
}
102+
77103
// CHECK-LABEL: @gather_tensor_1d
78104
// CHECK-SAME: ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
79105
// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : i1 from vector<2xi1>

0 commit comments

Comments
 (0)