Skip to content

[mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME #69604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
let defaultValue = "TileSliceLayout::Horizontal";
}

def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
I32EnumAttrCase<"Add", 0, "add">,
I32EnumAttrCase<"Sub", 1, "sub">,
]> {
let cppNamespace = "::mlir::arm_sme";
let genSpecializedAttr = 0;
}

/// An attribute that specifies how to combine a newly produced value with the
/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if it would make sense using the vector one and check if the kind is supported during verification?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try this, but including VectorAttributes.td within this file results in the generator creating duplicate definitions of attributes from the vector dialect (which fails to build).

/// the functions that are valid for SME outer products. Add corresponds to a
/// MOPA and sub to a MOPS.
/// E.g. For f32:
/// FMOPA: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
/// FMOPS: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPS--non-widening---Floating-point-outer-product-and-subtract-
def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
"kind"> {
let assemblyFormat = "`<` $value `>`";
let defaultValue = "CombiningKind::Add";
}

//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -209,7 +230,7 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let results = (outs SMETile:$res);
let description = [{
Initialise ZA with 0. This operation is convenient wrapper for the SME
`zero` intrinsic and instruction.
`zero` intrinsic and instruction.

Example 1: Zero an 8-bit element ZA tile.

Expand Down Expand Up @@ -561,4 +582,90 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
}

class HasMatchingMaskTypeConstraint<string operand> :
OptionalTypesMatchWith<
"shape of `" # operand # "Mask` matches `" # operand # "`",
operand, operand # "Mask",
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;

class OuterProductResultTileTypeConstraint<string operand> :
OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
"lhs", operand,
"[&]{"
" auto vectorType = ::llvm::cast<mlir::VectorType>($_self);"
" int64_t size = vectorType.getDimSize(0);"
" return VectorType::get("
" { size, size }, vectorType.getElementType(), { true, true });"
"}()">;

def OuterProductOp :
ArmSME_Op<"outerproduct", [Pure,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
HasMatchingMaskTypeConstraint<"lhs">,
HasMatchingMaskTypeConstraint<"rhs">,
PredOpTrait<
"both `lhsMask` and `rhsMask` should be provided or neither",
CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
OuterProductResultTileTypeConstraint<"result">,
OuterProductResultTileTypeConstraint<"acc">
]>
{
let summary = "Outer product with optional fused add/sub";

let description = [{
This operation represents an outer product that fits within an SME tile.
All operands must be SVE vectors and the result a SME tile. Unlike
`vector.outerproduct` masking is on the operands (rather than the result),
which mirrors the SME instructions.

Example 1: Unmasked outerproduct (without accumulator)
```mlir
// Not specifying an accumulator implicitly zeros the destination tile.
%result = arm_sme.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
```

Example 2: Unmasked outerproduct (with accumulator)
```mlir
%result = arm_sme.outerproduct $lhs, $rhs acc($accumulator)
: vector<[4]xf32>, vector<[4]xf32>
```

Example 3: Masked outerproduct
```mlir
%result = arm_sme.outerproduct $lhs, $rhs masks($lhsMask, $rhsMask)
: vector<[4]xf32>, vector<[4]xf32>
```

Example 4: Masked outerproduct (with accumulator)
```mlir
%result = arm_sme.outerproduct $lhs, $rhs acc($accumulator) masks($lhsMask, $rhsMask)
: vector<[4]xf32>, vector<[4]xf32>
```
}];

let arguments = (ins
SVEVector:$lhs, SVEVector:$rhs,
Optional<SVEPredicate>:$lhsMask,
Optional<SVEPredicate>:$rhsMask,
Optional<SMETile>: $acc,
ArmSME_CombiningKindAttr:$kind);
let results = (outs SMETile:$result);

let assemblyFormat = [{
$lhs `,` $rhs
oilist(
`kind` `` $kind
| `acc` `` `(` $acc `)`
| `masks` `` `(` $lhsMask `,` $rhsMask `)`
) attr-dict `:` type($lhs) `,` type($rhs)
}];

let extraClassDeclaration = [{
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
}];
}

#endif // ARMSME_OPS
102 changes: 101 additions & 1 deletion mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,112 @@ struct TransposeOpToArmSMELowering
}
};

/// Conversion pattern for vector.outerproduct.
///
/// If the vector.outerproduct is masked (and the mask is from a
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
/// operands.
///
/// Example:
///
/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
/// %result = vector.mask %mask {
/// vector.outerproduct %vecA, %vecB
/// : vector<[4]xf32>, vector<[4]xf32>
/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
///
/// is converted to:
///
/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
///
/// Example:
///
/// %result = vector.outerproduct %vecA, %vecB
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
/// %result = arm_sme.outerproduct %vecA, %vecB
/// : vector<[4]xf32>, vector<[4]xf32>
///
struct VectorOuterProductToArmSMELowering
: public OpRewritePattern<vector::OuterProductOp> {

using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
PatternRewriter &rewriter) const override {

// We don't yet support lowering AXPY operations to SME. These could be
// lowered by masking out all but the first element of the LHS.
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
return outerProductOp.emitError("AXPY operations not supported");

if (!arm_sme::isValidSMETileVectorType(
outerProductOp.getResultVectorType()))
return outerProductOp.emitError(
"outer product does not fit into SME tile");

auto kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
return outerProductOp.emitError(
"unsupported kind (lowering to SME only supports ADD at the moment)");

Value lhsMask = {};
Value rhsMask = {};
Operation *rootOp = outerProductOp;
auto loc = outerProductOp.getLoc();
if (outerProductOp.isMasked()) {
auto maskOp = outerProductOp.getMaskingOp();
rewriter.setInsertionPoint(maskOp);
rootOp = maskOp;
auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
if (failed(operandMasks))
return failure();
std::tie(lhsMask, rhsMask) = *operandMasks;
}

rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());

return success();
}

static FailureOr<std::pair<Value, Value>>
decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
// Attempt to extract masks from vector.create_mask.
// TODO: Add support for other mask sources.
auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return failure();

auto maskType = createMaskOp.getVectorType();
Value lhsMaskDim = createMaskOp.getOperand(0);
Value rhsMaskDim = createMaskOp.getOperand(1);

VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
Value lhsMask =
rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
Value rhsMask =
rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);

return std::make_pair(lhsMask, rhsMask);
}
};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
VectorOuterProductToArmSMELowering>(&ctx);
}
66 changes: 29 additions & 37 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,11 @@ struct MoveTileSliceToVectorArmSMELowering
}
};

/// Lower `vector.outerproduct` to SME MOPA intrinsics.
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
Expand All @@ -474,13 +474,13 @@ struct MoveTileSliceToVectorArmSMELowering
/// vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct VectorOuterProductToArmSMELowering
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
struct OuterProductOpConversion
: public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::OuterProductOp outerProductOp,
vector::OuterProductOp::Adaptor adaptor,
matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
arm_sme::OuterProductOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto isSupportedType = [](VectorType vectorType) {
// TODO: the FP outer product instruction variants are predicated on
Expand Down Expand Up @@ -512,24 +512,13 @@ struct VectorOuterProductToArmSMELowering
return true;
};

auto resultVectorType = outerProductOp.getResultVectorType();
if (!isSupportedType(resultVectorType))
return outerProductOp.emitError("unsupported type");

vector::CombiningKind kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
// TODO: support subtract.
// TODO: Support CombiningKind::Sub for outer products.
if (outerProductOp.getKind() != CombiningKind::Add)
return outerProductOp.emitError("unsupported kind");

auto maskableOp =
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
if (maskableOp.isMasked())
// TODO: support masking.
return outerProductOp.emitError("masking is currently unsupported");

if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
// AXPY operation not suited for SME.
return failure();
auto resultVectorType = outerProductOp.getResultType();
if (!isSupportedType(resultVectorType))
return outerProductOp.emitError("unsupported type");

auto loc = outerProductOp.getLoc();

Expand All @@ -542,21 +531,24 @@ struct VectorOuterProductToArmSMELowering
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(elementWidth), acc);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy =
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

auto tileI32 = castTileIDToI32(tileId, loc, rewriter);

Value lhsMask = outerProductOp.getLhsMask();
Value rhsMask = outerProductOp.getRhsMask();

if (!lhsMask || !rhsMask) {
auto predTy =
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
Value allActiveMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}

// Create 'arm_sme.intr.mopa' outer product intrinsic.
rewriter.create<arm_sme::aarch64_sme_mopa>(
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
outerProductOp.getRhs());
rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
outerProductOp.getLhs(),
outerProductOp.getRhs());

// Create `CastTileToVectorOp` to use as the output.
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
Expand Down Expand Up @@ -733,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
VectorOuterProductToArmSMELowering, ZeroOpConversion,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
VectorInsertToArmSMELowering>(converter);
}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/ArmSME/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,25 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
{
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
return %0 : vector<[2]x[2]xi16>
}

// -----

func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
{
// expected-error@+1 {{op failed to verify that all of {lhs, rhs} have same type}}
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
return %0 : vector<[4]x[4]xf32>
}
Loading