Skip to content

Commit e666295

Browse files
authored
[mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (llvm#69604)
This patch adds support for lowering masked outer products to SME. This is done in two stages.

First, vector.outerproducts (both masked and non-masked) are rewritten to arm_sme.outerproducts. The arm_sme.outerproduct op is close to vector.outerproduct, but supports masking on the operands rather than the result. It also limits the cases it handles to things that could be (directly) lowered to SME. This currently requires that the source of the mask is a vector.create_mask op.

E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
  vector.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```

Is rewritten to:

```mlir
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<[4]xf32>, vector<[4]xf32>
```

(The same rewrite works for non-masked vector.outerproducts too.)

Second, the arm_sme.outerproduct can then be directly lowered to SME intrinsics.
1 parent bbd61d8 commit e666295

File tree

9 files changed

+654
-64
lines changed

9 files changed

+654
-64
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,27 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
7979
let defaultValue = "TileSliceLayout::Horizontal";
8080
}
8181

82+
def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
83+
I32EnumAttrCase<"Add", 0, "add">,
84+
I32EnumAttrCase<"Sub", 1, "sub">,
85+
]> {
86+
let cppNamespace = "::mlir::arm_sme";
87+
let genSpecializedAttr = 0;
88+
}
89+
90+
/// An attribute that specifies how to combine a newly produced value with the
91+
/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
92+
/// the functions that are valid for SME outer products. Add corresponds to a
93+
/// MOPA and sub to a MOPS.
94+
/// E.g. For f32:
95+
/// FMOPA: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
96+
/// FMOPS: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPS--non-widening---Floating-point-outer-product-and-subtract-
97+
def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
98+
"kind"> {
99+
let assemblyFormat = "`<` $value `>`";
100+
let defaultValue = "CombiningKind::Add";
101+
}
102+
82103
//===----------------------------------------------------------------------===//
83104
// ArmSME op definitions
84105
//===----------------------------------------------------------------------===//
@@ -209,7 +230,7 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
209230
let results = (outs SMETile:$res);
210231
let description = [{
211232
Initialise ZA with 0. This operation is a convenient wrapper for the SME
212-
`zero` intrinsic and instruction.
233+
`zero` intrinsic and instruction.
213234

214235
Example 1: Zero an 8-bit element ZA tile.
215236

@@ -561,4 +582,90 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
561582
}];
562583
}
563584

585+
class HasMatchingMaskTypeConstraint<string operand> :
586+
OptionalTypesMatchWith<
587+
"shape of `" # operand # "Mask` matches `" # operand # "`",
588+
operand, operand # "Mask",
589+
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
590+
591+
class OuterProductResultTileTypeConstraint<string operand> :
592+
OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
593+
"lhs", operand,
594+
"[&]{"
595+
" auto vectorType = ::llvm::cast<mlir::VectorType>($_self);"
596+
" int64_t size = vectorType.getDimSize(0);"
597+
" return VectorType::get("
598+
" { size, size }, vectorType.getElementType(), { true, true });"
599+
"}()">;
600+
601+
def OuterProductOp :
602+
ArmSME_Op<"outerproduct", [Pure,
603+
AttrSizedOperandSegments,
604+
AllTypesMatch<["lhs", "rhs"]>,
605+
HasMatchingMaskTypeConstraint<"lhs">,
606+
HasMatchingMaskTypeConstraint<"rhs">,
607+
PredOpTrait<
608+
"both `lhsMask` and `rhsMask` should be provided or neither",
609+
CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
610+
OuterProductResultTileTypeConstraint<"result">,
611+
OuterProductResultTileTypeConstraint<"acc">
612+
]>
613+
{
614+
let summary = "Outer product with optional fused add/sub";
615+
616+
let description = [{
617+
This operation represents an outer product that fits within an SME tile.
618+
All operands must be SVE vectors and the result an SME tile. Unlike
619+
`vector.outerproduct` masking is on the operands (rather than the result),
620+
which mirrors the SME instructions.
621+
622+
Example 1: Unmasked outerproduct (without accumulator)
623+
```mlir
624+
// Not specifying an accumulator implicitly zeros the destination tile.
625+
%result = arm_sme.outerproduct %lhs, %rhs : vector<[4]xf32>, vector<[4]xf32>
626+
```
627+
628+
Example 2: Unmasked outerproduct (with accumulator)
629+
```mlir
630+
%result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator)
631+
: vector<[4]xf32>, vector<[4]xf32>
632+
```
633+
634+
Example 3: Masked outerproduct
635+
```mlir
636+
%result = arm_sme.outerproduct %lhs, %rhs masks(%lhsMask, %rhsMask)
637+
: vector<[4]xf32>, vector<[4]xf32>
638+
```
639+
640+
Example 4: Masked outerproduct (with accumulator)
641+
```mlir
642+
%result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator) masks(%lhsMask, %rhsMask)
643+
: vector<[4]xf32>, vector<[4]xf32>
644+
```
645+
}];
646+
647+
let arguments = (ins
648+
SVEVector:$lhs, SVEVector:$rhs,
649+
Optional<SVEPredicate>:$lhsMask,
650+
Optional<SVEPredicate>:$rhsMask,
651+
Optional<SMETile>: $acc,
652+
ArmSME_CombiningKindAttr:$kind);
653+
let results = (outs SMETile:$result);
654+
655+
let assemblyFormat = [{
656+
$lhs `,` $rhs
657+
oilist(
658+
`kind` `` $kind
659+
| `acc` `` `(` $acc `)`
660+
| `masks` `` `(` $lhsMask `,` $rhsMask `)`
661+
) attr-dict `:` type($lhs) `,` type($rhs)
662+
}];
663+
664+
let extraClassDeclaration = [{
665+
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
666+
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
667+
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
668+
}];
669+
}
670+
564671
#endif // ARMSME_OPS

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,12 +427,112 @@ struct TransposeOpToArmSMELowering
427427
}
428428
};
429429

430+
/// Conversion pattern for vector.outerproduct.
431+
///
432+
/// If the vector.outerproduct is masked (and the mask is from a
433+
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
434+
/// operands.
435+
///
436+
/// Example:
437+
///
438+
/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
439+
/// %result = vector.mask %mask {
440+
/// vector.outerproduct %vecA, %vecB
441+
/// : vector<[4]xf32>, vector<[4]xf32>
442+
/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
443+
///
444+
/// is converted to:
445+
///
446+
/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
447+
/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
448+
/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
449+
/// : vector<[4]xf32>, vector<[4]xf32>
450+
///
451+
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
452+
///
453+
/// Example:
454+
///
455+
/// %result = vector.outerproduct %vecA, %vecB
456+
/// : vector<[4]xf32>, vector<[4]xf32>
457+
///
458+
/// is converted to:
459+
///
460+
/// %result = arm_sme.outerproduct %vecA, %vecB
461+
/// : vector<[4]xf32>, vector<[4]xf32>
462+
///
463+
struct VectorOuterProductToArmSMELowering
464+
: public OpRewritePattern<vector::OuterProductOp> {
465+
466+
using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
467+
468+
LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
469+
PatternRewriter &rewriter) const override {
470+
471+
// We don't yet support lowering AXPY operations to SME. These could be
472+
// lowered by masking out all but the first element of the LHS.
473+
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
474+
return outerProductOp.emitError("AXPY operations not supported");
475+
476+
if (!arm_sme::isValidSMETileVectorType(
477+
outerProductOp.getResultVectorType()))
478+
return outerProductOp.emitError(
479+
"outer product does not fit into SME tile");
480+
481+
auto kind = outerProductOp.getKind();
482+
if (kind != vector::CombiningKind::ADD)
483+
return outerProductOp.emitError(
484+
"unsupported kind (lowering to SME only supports ADD at the moment)");
485+
486+
Value lhsMask = {};
487+
Value rhsMask = {};
488+
Operation *rootOp = outerProductOp;
489+
auto loc = outerProductOp.getLoc();
490+
if (outerProductOp.isMasked()) {
491+
auto maskOp = outerProductOp.getMaskingOp();
492+
rewriter.setInsertionPoint(maskOp);
493+
rootOp = maskOp;
494+
auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
495+
if (failed(operandMasks))
496+
return failure();
497+
std::tie(lhsMask, rhsMask) = *operandMasks;
498+
}
499+
500+
rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
501+
rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
502+
outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
503+
504+
return success();
505+
}
506+
507+
static FailureOr<std::pair<Value, Value>>
508+
decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
509+
// Attempt to extract masks from vector.create_mask.
510+
// TODO: Add support for other mask sources.
511+
auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
512+
if (!createMaskOp)
513+
return failure();
514+
515+
auto maskType = createMaskOp.getVectorType();
516+
Value lhsMaskDim = createMaskOp.getOperand(0);
517+
Value rhsMaskDim = createMaskOp.getOperand(1);
518+
519+
VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
520+
Value lhsMask =
521+
rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
522+
Value rhsMask =
523+
rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
524+
525+
return std::make_pair(lhsMask, rhsMask);
526+
}
527+
};
528+
430529
} // namespace
431530

432531
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
433532
MLIRContext &ctx) {
434533
patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
435534
SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
436535
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
437-
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
536+
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
537+
VectorOuterProductToArmSMELowering>(&ctx);
438538
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,11 @@ struct MoveTileSliceToVectorArmSMELowering
460460
}
461461
};
462462

463-
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
463+
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
464464
///
465465
/// Example:
466466
///
467-
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
467+
/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
468468
/// : vector<[4]xf32>, vector<[4]xf32>
469469
///
470470
/// is converted to:
@@ -474,13 +474,13 @@ struct MoveTileSliceToVectorArmSMELowering
474474
/// vector<[4]xf32>) -> ()
475475
///
476476
/// Currently only supports FMOPA and BFMOPA (non-widening).
477-
struct VectorOuterProductToArmSMELowering
478-
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
479-
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
477+
struct OuterProductOpConversion
478+
: public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
479+
using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
480480

481481
LogicalResult
482-
matchAndRewrite(vector::OuterProductOp outerProductOp,
483-
vector::OuterProductOp::Adaptor adaptor,
482+
matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
483+
arm_sme::OuterProductOp::Adaptor adaptor,
484484
ConversionPatternRewriter &rewriter) const override {
485485
auto isSupportedType = [](VectorType vectorType) {
486486
// TODO: the FP outer product instruction variants are predicated on
@@ -512,24 +512,13 @@ struct VectorOuterProductToArmSMELowering
512512
return true;
513513
};
514514

515-
auto resultVectorType = outerProductOp.getResultVectorType();
516-
if (!isSupportedType(resultVectorType))
517-
return outerProductOp.emitError("unsupported type");
518-
519-
vector::CombiningKind kind = outerProductOp.getKind();
520-
if (kind != vector::CombiningKind::ADD)
521-
// TODO: support subtract.
515+
// TODO: Support CombiningKind::Sub for outer products.
516+
if (outerProductOp.getKind() != CombiningKind::Add)
522517
return outerProductOp.emitError("unsupported kind");
523518

524-
auto maskableOp =
525-
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
526-
if (maskableOp.isMasked())
527-
// TODO: support masking.
528-
return outerProductOp.emitError("masking is currently unsupported");
529-
530-
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
531-
// AXPY operation not suited for SME.
532-
return failure();
519+
auto resultVectorType = outerProductOp.getResultType();
520+
if (!isSupportedType(resultVectorType))
521+
return outerProductOp.emitError("unsupported type");
533522

534523
auto loc = outerProductOp.getLoc();
535524

@@ -542,21 +531,24 @@ struct VectorOuterProductToArmSMELowering
542531
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
543532
loc, rewriter.getIntegerType(elementWidth), acc);
544533

545-
// Create all active predicate mask.
546-
auto one = rewriter.create<arith::ConstantOp>(
547-
loc, rewriter.getI1Type(),
548-
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
549-
auto predTy =
550-
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
551-
/*scalableDims=*/{true});
552-
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
553-
554534
auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
555535

536+
Value lhsMask = outerProductOp.getLhsMask();
537+
Value rhsMask = outerProductOp.getRhsMask();
538+
539+
if (!lhsMask || !rhsMask) {
540+
auto predTy =
541+
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
542+
Value allActiveMask = rewriter.create<arith::ConstantOp>(
543+
loc, DenseElementsAttr::get(predTy, true));
544+
lhsMask = allActiveMask;
545+
rhsMask = allActiveMask;
546+
}
547+
556548
// Create 'arm_sme.intr.mopa' outer product intrinsic.
557-
rewriter.create<arm_sme::aarch64_sme_mopa>(
558-
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
559-
outerProductOp.getRhs());
549+
rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
550+
outerProductOp.getLhs(),
551+
outerProductOp.getRhs());
560552

561553
// Create `CastTileToVectorOp` to use as the output.
562554
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
@@ -733,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
733725
patterns.add<
734726
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
735727
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
736-
VectorOuterProductToArmSMELowering, ZeroOpConversion,
737-
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
728+
OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
729+
VectorInsertToArmSMELowering>(converter);
738730
}

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,25 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
150150
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
151151
return
152152
}
153+
154+
//===----------------------------------------------------------------------===//
155+
// arm_sme.outerproduct
156+
//===----------------------------------------------------------------------===//
157+
158+
// -----
159+
160+
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
161+
{
162+
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
163+
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
164+
return %0 : vector<[2]x[2]xi16>
165+
}
166+
167+
// -----
168+
169+
func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
170+
{
171+
// expected-error@+1 {{op failed to verify that all of {lhs, rhs} have same type}}
172+
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
173+
return %0 : vector<[4]x[4]xf32>
174+
}

0 commit comments

Comments
 (0)