Skip to content

[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel #96066

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
Pass that converts vector dialect operations into equivalent ArmSME dialect
operations.
}];
let dependentDialects = ["arm_sme::ArmSMEDialect"];
let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME

LINK_LIBS PUBLIC
MLIRArmSMEDialect
MLIRArmSVEDialect
MLIRLLVMCommonConversion
)
87 changes: 79 additions & 8 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -719,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice
}
};

/// Rewrites a `vector.extract` that reads one slice of a 2-D scalable
/// `vector.create_mask` into an `arm_sve.psel`. Note: Although psel belongs to
/// the ArmSVE dialect it requires SME (or SVE 2.1), so this is currently the
/// most logical home for the lowering.
///
/// Before:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
/// %slice = vector.extract %mask[%index]
///            : vector<[8]xi1> from vector<[4]x[8]xi1>
/// ```
/// After:
/// ```mlir
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
///            : vector<[8]xi1>, vector<[4]xi1>
/// ```
///
/// The pattern only applies when the extract has exactly one index, yields a
/// vector, reads directly from a `vector.create_mask`, and both mask
/// dimensions are scalable with a power-of-two base size in the range [1, 16].
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Exactly one position (static or dynamic) must be supplied.
    if (extractOp.getNumIndices() != 1)
      return rewriter.notifyMatchFailure(extractOp, "not single extract index");

    // The extracted value has to be a vector rather than a scalar element.
    if (!dyn_cast<VectorType>(extractOp.getResult().getType()))
      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

    // The extract source must be produced directly by a vector.create_mask.
    auto maskOp = extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!maskOp)
      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

    VectorType mask2DType = maskOp.getVectorType();
    const bool is2DScalable =
        mask2DType.getRank() == 2 && mask2DType.allDimsScalable();
    if (!is2DScalable)
      return rewriter.notifyMatchFailure(maskOp, "not 2-D scalable mask");

    // Accept only base sizes that are a power of two no larger than 16.
    auto fitsSVEPredicate = [](int64_t dimSize) {
      if (dimSize <= 0 || dimSize > 16)
        return false;
      return llvm::isPowerOf2_32(static_cast<uint32_t>(dimSize));
    };
    const int64_t numRows = mask2DType.getDimSize(0);
    const int64_t numCols = mask2DType.getDimSize(1);
    if (!(fitsSVEPredicate(numRows) && fitsSVEPredicate(numCols)))
      return rewriter.notifyMatchFailure(
          maskOp, "mask dimensions not SVE predicate-sized");

    Location loc = extractOp.getLoc();
    // Dropping one dimension of the 2-D type yields the 1-D type of the other.
    VectorType rowsType = VectorType::Builder(mask2DType).dropDim(1);
    VectorType colsType = VectorType::Builder(mask2DType).dropDim(0);

    // Emit both 1-D masks right where the 2-D create_mask lives (usually
    // outside a loop), so that no hoisting is needed afterwards.
    rewriter.setInsertionPoint(maskOp);
    auto rowsMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowsType, maskOp.getOperand(0));
    auto colsMask = rewriter.create<vector::CreateMaskOp>(
        loc, colsType, maskOp.getOperand(1));

    // Materialise the extract position as a Value next to the extract, then
    // swap the extract for the psel.
    rewriter.setInsertionPoint(extractOp);
    Value sliceIndex =
        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition())
            .front();
    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colsMask, rowsMask,
                                                 sliceIndex);
    return success();
  }
};

} // namespace

/// Adds the Vector-to-ArmSME rewrite patterns listed below to `patterns`,
/// constructing each of them with `&ctx`.
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
// NOTE(review): this text comes from a rendered diff, so both sides of the
// change appear back to back. The statement below looks like the removed
// (pre-change) pattern list -- confirm against the merged file.
patterns
.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
&ctx);
// Presumably the added (post-change) side: the same list with
// ExtractFromCreateMaskToPselLowering appended.
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Conversion/VectorToArmSME/unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}

// -----

/// Negative test: not SVE predicate-sized. The trailing mask dimension has a
/// base size of 32, which is larger than the maximum of 16 accepted by the
/// lowering, so the extract must be left alone.

// CHECK-LABEL: @negative_vector_extract_to_psel_0
func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
{
// CHECK-NOT: arm_sve.psel
%mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
%slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
return %slice : vector<[32]xi1>
}

// -----

/// Negative test: source not a 2-D scalable mask. The leading mask dimension
/// is fixed-size (4) rather than scalable, so not all dimensions are scalable
/// and the extract must be left alone.

// CHECK-LABEL: @negative_vector_extract_to_psel_1
func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
{
// CHECK-NOT: arm_sve.psel
%mask = vector.create_mask %a, %b : vector<4x[8]xi1>
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
return %slice : vector<[8]xi1>
}

// -----

/// Negative test: source not vector.create_mask. The mask is a function
/// argument, so the extract has no defining vector.create_mask to split and
/// must be left alone.

// CHECK-LABEL: @negative_vector_extract_to_psel_2
func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
{
// CHECK-NOT: arm_sve.psel
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
return %slice : vector<[8]xi1>
}

// -----

/// Negative test: not a psel-like extract. Two indices are supplied and the
/// result is a scalar i1 instead of a 1-D mask, so the extract must be left
/// alone.

// CHECK-LABEL: @negative_vector_extract_to_psel_3
func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
{
// CHECK-NOT: arm_sve.psel
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
return %el : i1
}
36 changes: 35 additions & 1 deletion mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
}

//===----------------------------------------------------------------------===//
// vector.extract
// vector.extract --> arm_sme.move_tile_slice_to_vector
//===----------------------------------------------------------------------===//

// -----
Expand Down Expand Up @@ -1320,3 +1320,37 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
return %el : f64
}

//===----------------------------------------------------------------------===//
// vector.extract --> arm_sve.psel
//===----------------------------------------------------------------------===//

// -----

/// Positive test with a dynamic extract index: the 2-D scalable create_mask is
/// expected to be split into a row mask and a column mask, which are then
/// combined by arm_sve.psel using the original index.

// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[INDEX:.*]]: index)
func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
{
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
return %slice : vector<[8]xi1>
}

// -----

/// Positive test with a constant extract index (1): the index is expected to
/// be materialized as an arith.constant that feeds the arm_sve.psel built from
/// the split row and column masks.

// CHECK-LABEL: @vector_extract_mask_to_psel(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: index)
func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
{
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
%mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
%slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
return %slice : vector<[2]xi1>
}
Loading