Skip to content

Commit ccef726

Browse files
committed
[mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)
This is a follow-on to D158753, and allows the lowering of a transfer read/write of n-D vectors with a single trailing scalable dimension to primitive vector ops. The final conversion to LLVM depends on D158517 and D158752, without these patches type conversion will fail (or an assert is hit in the LLVM backend) if the final IR contains an array of scalable vectors. This patch adds `transform.apply_patterns.vector.lower_create_mask` which allows the lowering of vector.create_mask/constant_mask to be tested independently of --convert-vector-to-llvm. Reviewed By: c-rhodes, awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D159482
1 parent 6bf923d commit ccef726

File tree

6 files changed

+84
-4
lines changed

6 files changed

+84
-4
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ def ApplyLowerContractionPatternsOp : Op<Transform_Dialect,
122122
}];
123123
}
124124

125+
def ApplyLowerCreateMaskPatternsOp : Op<Transform_Dialect,
126+
"apply_patterns.vector.lower_create_mask",
127+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
128+
let description = [{
129+
Indicates that vector create_mask-like operations should be lowered to
130+
finer-grained vector primitives.
131+
}];
132+
133+
let assemblyFormat = "attr-dict";
134+
}
135+
125136
def ApplyLowerMasksPatternsOp : Op<Transform_Dialect,
126137
"apply_patterns.vector.lower_masks",
127138
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
6464
vector::populateVectorReductionToContractPatterns(patterns);
6565
}
6666

67+
void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
68+
RewritePatternSet &patterns) {
69+
vector::populateVectorMaskOpLoweringPatterns(patterns);
70+
}
71+
6772
void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
6873
RewritePatternSet &patterns) {
6974
vector::populateVectorTransferDropUnitDimsPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
5858
return rewriter.notifyMatchFailure(
5959
op, "0-D and 1-D vectors are handled separately");
6060

61+
if (dstType.getScalableDims().front())
62+
return rewriter.notifyMatchFailure(
63+
op, "Cannot unroll leading scalable dim in dstType");
64+
6165
auto loc = op.getLoc();
62-
auto eltType = dstType.getElementType();
6366
int64_t dim = dstType.getDimSize(0);
6467
Value idx = op.getOperand(0);
6568

66-
VectorType lowType =
67-
VectorType::get(dstType.getShape().drop_front(), eltType);
69+
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
6870
Value trueVal = rewriter.create<vector::CreateMaskOp>(
6971
loc, lowType, op.getOperands().drop_front());
7072
Value falseVal = rewriter.create<arith::ConstantOp>(

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ struct TransferReadToVectorLoadLowering
434434
vectorShape.end());
435435
for (unsigned i : broadcastedDims)
436436
unbroadcastedVectorShape[i] = 1;
437-
VectorType unbroadcastedVectorType = VectorType::get(
437+
VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
438438
unbroadcastedVectorShape, read.getVectorType().getElementType());
439439

440440
// `vector.load` supports vector types as memref's elements only when the

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,6 +1743,28 @@ func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5x
17431743

17441744
// -----
17451745

1746+
// CHECK-LABEL: func @transfer_read_1d_scalable_mask
1747+
// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
1748+
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
1749+
// CHECK: return %[[r]] : vector<[4]xf32>
1750+
func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
1751+
%c0 = arith.constant 0 : index
1752+
%pad = arith.constant 0.0 : f32
1753+
%vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
1754+
return %vec : vector<[4]xf32>
1755+
}
1756+
1757+
// -----
1758+
// CHECK-LABEL: func @transfer_write_1d_scalable_mask
1759+
// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
1760+
func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
1761+
%c0 = arith.constant 0 : index
1762+
vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
1763+
return
1764+
}
1765+
1766+
// -----
1767+
17461768
func.func @genbool_0d_f() -> vector<i1> {
17471769
%0 = vector.constant_mask [0] : vector<i1>
17481770
return %0 : vector<i1>
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable(
4+
// CHECK-SAME: %[[arg:.*]]: index) -> vector<3x[4]xi1> {
5+
// CHECK-NEXT: %[[zero_mask_1d:.*]] = arith.constant dense<false> : vector<[4]xi1>
6+
// CHECK-NEXT: %[[zero_mask_2d:.*]] = arith.constant dense<false> : vector<3x[4]xi1>
7+
// CHECK-NEXT: %[[create_mask_1d:.*]] = vector.create_mask %[[arg]] : vector<[4]xi1>
8+
// CHECK-NEXT: %[[res_0:.*]] = vector.insert %[[create_mask_1d]], %[[zero_mask_2d]] [0] : vector<[4]xi1> into vector<3x[4]xi1>
9+
// CHECK-NEXT: %[[res_1:.*]] = vector.insert %[[create_mask_1d]], %[[res_0]] [1] : vector<[4]xi1> into vector<3x[4]xi1>
10+
// CHECK-NEXT: %[[res_2:.*]] = vector.insert %[[zero_mask_1d]], %[[res_1]] [2] : vector<[4]xi1> into vector<3x[4]xi1>
11+
// CHECK-NEXT: return %[[res_2]] : vector<3x[4]xi1>
12+
func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<3x[4]xi1> {
13+
%c2 = arith.constant 2 : index
14+
%mask = vector.create_mask %c2, %a : vector<3x[4]xi1>
15+
return %mask : vector<3x[4]xi1>
16+
}
17+
18+
// -----
19+
20+
/// The following cannot be lowered as the current lowering requires unrolling
21+
/// the leading dim.
22+
23+
// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable(
24+
// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> {
25+
// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1>
26+
func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> {
27+
%c1 = arith.constant 1 : index
28+
%mask = vector.create_mask %a, %c1 : vector<[4]x4xi1>
29+
return %mask : vector<[4]x4xi1>
30+
}
31+
32+
transform.sequence failures(suppress) {
33+
^bb1(%module_op: !transform.any_op):
34+
%f = transform.structured.match ops{["func.func"]} in %module_op
35+
: (!transform.any_op) -> !transform.any_op
36+
37+
transform.apply_patterns to %f {
38+
transform.apply_patterns.vector.lower_create_mask
39+
} : !transform.any_op
40+
}

0 commit comments

Comments
 (0)