Skip to content

Commit e2296d8

Browse files
authored
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (#96066)
Example: ```mlir %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> ``` Becomes: ```mlir %mask_rows = vector.create_mask %a : vector<[4]xi1> %mask_cols = vector.create_mask %b : vector<[8]xi1> %slice = arm_sve.psel %mask_cols, %mask_rows[%index] : vector<[8]xi1>, vector<[4]xi1> ``` Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
1 parent 94fdfc1 commit e2296d8

File tree

6 files changed

+168
-10
lines changed

6 files changed

+168
-10
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
12761276
Pass that converts vector dialect operations into equivalent ArmSME dialect
12771277
operations.
12781278
}];
1279-
let dependentDialects = ["arm_sme::ArmSMEDialect"];
1279+
let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
12801280
}
12811281

12821282
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
1010

1111
LINK_LIBS PUBLIC
1212
MLIRArmSMEDialect
13+
MLIRArmSVEDialect
1314
MLIRLLVMCommonConversion
1415
)

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1212
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
13+
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1314
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1415
#include "mlir/IR/BuiltinTypes.h"
1516
#include "llvm/Support/Casting.h"
@@ -719,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice
719720
}
720721
};
721722

723+
/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
724+
/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
725+
/// SVE 2.1), so this is currently the most logical place for this lowering.
726+
///
727+
/// Example:
728+
/// ```mlir
729+
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
730+
/// %slice = vector.extract %mask[%index]
731+
/// : vector<[8]xi1> from vector<[4]x[8]xi1>
732+
/// ```
733+
/// Becomes:
734+
/// ```
735+
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
736+
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
737+
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
738+
/// : vector<[8]xi1>, vector<[4]xi1>
739+
/// ```
740+
struct ExtractFromCreateMaskToPselLowering
741+
: public OpRewritePattern<vector::ExtractOp> {
742+
using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
743+
744+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
745+
PatternRewriter &rewriter) const override {
746+
if (extractOp.getNumIndices() != 1)
747+
return rewriter.notifyMatchFailure(extractOp, "not single extract index");
748+
749+
auto resultType = extractOp.getResult().getType();
750+
auto resultVectorType = dyn_cast<VectorType>(resultType);
751+
if (!resultVectorType)
752+
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
753+
754+
auto createMaskOp =
755+
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
756+
if (!createMaskOp)
757+
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
758+
759+
auto maskType = createMaskOp.getVectorType();
760+
if (maskType.getRank() != 2 || !maskType.allDimsScalable())
761+
return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
762+
763+
auto isSVEPredicateSize = [](int64_t size) {
764+
return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
765+
};
766+
767+
auto rowsBaseSize = maskType.getDimSize(0);
768+
auto colsBaseSize = maskType.getDimSize(1);
769+
if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
770+
return rewriter.notifyMatchFailure(
771+
createMaskOp, "mask dimensions not SVE predicate-sized");
772+
773+
auto loc = extractOp.getLoc();
774+
VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
775+
VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
776+
777+
// Create the two 1-D masks at the location of the 2-D create_mask (which is
778+
// usually outside a loop). This prevents the need for later hoisting.
779+
rewriter.setInsertionPoint(createMaskOp);
780+
auto rowMask = rewriter.create<vector::CreateMaskOp>(
781+
loc, rowMaskType, createMaskOp.getOperand(0));
782+
auto colMask = rewriter.create<vector::CreateMaskOp>(
783+
loc, colMaskType, createMaskOp.getOperand(1));
784+
785+
rewriter.setInsertionPoint(extractOp);
786+
auto position =
787+
vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
788+
rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
789+
position[0]);
790+
return success();
791+
}
792+
};
793+
722794
} // namespace
723795

724796
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
725797
MLIRContext &ctx) {
726-
patterns
727-
.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
728-
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
729-
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
730-
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
731-
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
732-
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
733-
&ctx);
798+
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
799+
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
800+
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
801+
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
802+
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
803+
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
804+
ExtractFromCreateMaskToPselLowering>(&ctx);
734805
}

mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
1010

1111
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12+
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1415

mlir/test/Conversion/VectorToArmSME/unsupported.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
192192
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
193193
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
194194
}
195+
196+
// -----
197+
198+
/// Not SVE predicate-sized.
199+
200+
// CHECK-LABEL: @negative_vector_extract_to_psel_0
201+
func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
202+
{
203+
// CHECK-NOT: arm_sve.psel
204+
%mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
205+
%slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
206+
return %slice : vector<[32]xi1>
207+
}
208+
209+
// -----
210+
211+
/// Source not 2-D scalable mask.
212+
213+
// CHECK-LABEL: @negative_vector_extract_to_psel_1
214+
func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
215+
{
216+
// CHECK-NOT: arm_sve.psel
217+
%mask = vector.create_mask %a, %b : vector<4x[8]xi1>
218+
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
219+
return %slice : vector<[8]xi1>
220+
}
221+
222+
// -----
223+
224+
/// Source not vector.create_mask.
225+
226+
// CHECK-LABEL: @negative_vector_extract_to_psel_2
227+
func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
228+
{
229+
// CHECK-NOT: arm_sve.psel
230+
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
231+
return %slice : vector<[8]xi1>
232+
}
233+
234+
// -----
235+
236+
/// Not psel-like extract.
237+
238+
// CHECK-LABEL: @negative_vector_extract_to_psel_3
239+
func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
240+
{
241+
// CHECK-NOT: arm_sve.psel
242+
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
243+
%el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
244+
return %el : i1
245+
}

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
11241124
}
11251125

11261126
//===----------------------------------------------------------------------===//
1127-
// vector.extract
1127+
// vector.extract --> arm_sme.move_tile_slice_to_vector
11281128
//===----------------------------------------------------------------------===//
11291129

11301130
// -----
@@ -1320,3 +1320,37 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
13201320
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
13211321
return %el : f64
13221322
}
1323+
1324+
//===----------------------------------------------------------------------===//
1325+
// vector.extract --> arm_sve.psel
1326+
//===----------------------------------------------------------------------===//
1327+
1328+
// -----
1329+
1330+
// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
1331+
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[INDEX:.*]]: index)
1332+
func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
1333+
{
1334+
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
1335+
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
1336+
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
1337+
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
1338+
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
1339+
return %slice : vector<[8]xi1>
1340+
}
1341+
1342+
// -----
1343+
1344+
// CHECK-LABEL: @vector_extract_mask_to_psel(
1345+
// CHECK-SAME: %[[A:.*]]: index,
1346+
// CHECK-SAME: %[[B:.*]]: index)
1347+
func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
1348+
{
1349+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
1350+
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
1351+
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
1352+
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
1353+
%mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
1354+
%slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
1355+
return %slice : vector<[2]xi1>
1356+
}

0 commit comments

Comments
 (0)