Skip to content

Commit 91e4da6

Browse files
committed
Review fixups
1 parent b6015b0 commit 91e4da6

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,15 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
124124
if (rank == 1) {
125125
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
126126
// Use constant splat for 'all set' or 'none set' dims.
127-
// This produces correct code for scalable dimensions.
127+
// This produces correct code for scalable dimensions (it will lower to
128+
// a constant splat).
128129
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
129130
op, DenseElementsAttr::get(dstType, trueDimSize != 0));
130131
} else {
131132
// Express constant 1-D case in explicit vector form:
132133
// [T,..,T,F,..,F].
133-
SmallVector<bool> values(dstType.getDimSize(0));
134+
// Note: The verifier would reject this case for scalable vectors.
135+
SmallVector<bool> values(dstType.getDimSize(0), false);
134136
for (int64_t d = 0; d < trueDimSize; d++)
135137
values[d] = true;
136138
rewriter.replaceOpWithNewOp<arith::ConstantOp>(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,39 +1819,38 @@ func.func @genbool_1d() -> vector<8xi1> {
18191819

18201820
// -----
18211821

1822-
func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> {
1822+
func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> {
18231823
%0 = vector.constant_mask [0] : vector<[8]xi1>
18241824
return %0 : vector<[8]xi1>
18251825
}
1826-
// CHECK-LABEL: func @genbool_1d_scalable_pfalse
1826+
// CHECK-LABEL: func @genbool_1d_scalable_all_false
18271827
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
18281828
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
18291829

18301830
// -----
18311831

1832-
func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> {
1832+
func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> {
18331833
%0 = vector.constant_mask [8] : vector<[8]xi1>
18341834
return %0 : vector<[8]xi1>
18351835
}
1836-
// CHECK-LABEL: func @genbool_1d_scalable_ptrue
1836+
// CHECK-LABEL: func @genbool_1d_scalable_all_true
18371837
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
18381838
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
18391839

18401840
// -----
18411841

1842-
func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
1842+
func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> {
18431843
%0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
18441844
return %0 : vector<4x[4]xi1>
18451845
}
1846-
// CHECK-LABEL: func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
1846+
// CHECK-LABEL: func.func @genbool_2d_trailing_scalable
18471847
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
18481848
// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
18491849
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
18501850
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
18511851
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
18521852
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
18531853
// CHECK: return %[[VAL_5]] : vector<4x[4]xi1>
1854-
// CHECK: }
18551854

18561855
// -----
18571856

@@ -1861,10 +1860,9 @@ func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
18611860
%0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
18621861
return %0 : vector<[4]x4xi1>
18631862
}
1864-
// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
1863+
// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable
18651864
// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
18661865
// CHECK: return %[[VAL_0]] : vector<[4]x4xi1>
1867-
// CHECK: }
18681866

18691867
// -----
18701868

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
10071007
%C: vector<3x[8]xf32>,
10081008
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
10091009
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
1010-
%0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
1010+
%0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
10111011
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
10121012
return %0 : vector<3x[8]xf32>
10131013
}

0 commit comments

Comments
 (0)