Skip to content

Commit 589c0d2

Browse files
committed
Address comments. Changes:
- hasOneUse
- Negative tests for pass
- move types and isWidening checks to isSupported method
- negative test for unsupported type
- move masking check
- op1.erase() -> rewriter.eraseOp(op1);
- braces round dangling-else
- TypeSwitch
- rename isWidenable and add comments for clarity
- add TODO/REDEFINE for QEMU bug to make it clearer
- s/widening/fusion/g
- drop wide from op names
1 parent 26f705d commit 589c0d2

File tree

13 files changed

+620
-405
lines changed

13 files changed

+620
-405
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -814,10 +814,10 @@ let arguments = (ins
814814
}];
815815
}
816816

817-
class OuterProductWideBase<string mnemonic,
818-
list<Type> allowedInputVectorTypes,
819-
list<Type> allowedResultVectorTypes,
820-
int numOuterProducts> :
817+
class OuterProductWideningBase<string mnemonic,
818+
list<Type> allowedInputVectorTypes,
819+
list<Type> allowedResultVectorTypes,
820+
int numOuterProducts> :
821821
ArmSME_Op<mnemonic, [
822822
ArmSMETileOpInterface,
823823
AttrSizedOperandSegments,
@@ -869,14 +869,14 @@ class OuterProductWideBase<string mnemonic,
869869
}];
870870
}
871871

872-
class OuterProductWide2Way<string mnemonic,
873-
list<Type> allowedInputVectorTypes,
874-
list<Type> allowedResultVectorTypes>
875-
: OuterProductWideBase<mnemonic, allowedInputVectorTypes,
876-
allowedResultVectorTypes, /*numOuterProducts=*/2>;
872+
class OuterProduct2Way<string mnemonic,
873+
list<Type> allowedInputVectorTypes,
874+
list<Type> allowedResultVectorTypes>
875+
: OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
876+
allowedResultVectorTypes, /*numOuterProducts=*/2>;
877877

878-
def FMopaWide2WayOp
879-
: OuterProductWide2Way<"fmopa_wide_2way",
878+
def FMopa2WayOp
879+
: OuterProduct2Way<"fmopa_2way",
880880
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
881881
[nxnxv4f32]> {
882882
let summary = "Floating-point sum of 2 outer products and accumulate";
@@ -888,14 +888,14 @@ def FMopaWide2WayOp
888888
For example (fp16 to fp32):
889889

890890
```mlir
891-
%result = arm_sme.fmopa_wide_2way %lhs, %rhs :
891+
%result = arm_sme.fmopa_2way %lhs, %rhs :
892892
vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
893893
```
894894

895895
The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
896896
2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
897897
elements in a vector of SVL bits. To illustrate, below is a breakdown of
898-
this operation for SVL=128 (i.e., vscale=1):
898+
this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):
899899

900900
```
901901
LHS RHS
@@ -960,19 +960,19 @@ def FMopaWide2WayOp
960960
```mlir
961961
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
962962
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
963-
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
963+
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
964964
```
965965

966-
This is implemented in the `-arm-sme-outer-product-widening` pass.
966+
This is implemented in the `-arm-sme-outer-product-fusion` pass.
967967

968968
Example: FP16 to FP32
969969
```mlir
970-
%result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
970+
%result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
971971
```
972972

973973
Example: BF16 to FP32
974974
```mlir
975-
%result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
975+
%result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
976976
```
977977

978978
| Spec | Features |
@@ -989,27 +989,27 @@ def FMopaWide2WayOp
989989
// - FMOPA 4-way FP16 to FP32
990990
// once intrinsic support lands in the backend.
991991

992-
def FMopsWide2WayOp
993-
: OuterProductWide2Way<"fmops_wide_2way",
992+
def FMops2WayOp
993+
: OuterProduct2Way<"fmops_2way",
994994
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
995995
[nxnxv4f32]> {
996996
let summary = "Floating-point sum of 2 outer products and subtract";
997997
let description = [{
998-
Equivalent to `fmopa_wide_2way` but outer products are subtracted from
998+
Equivalent to `fmopa_2way` but outer products are subtracted from
999999
destination `result`.
10001000

10011001
Example: FP16 to FP32
10021002
```mlir
1003-
%result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1003+
%result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
10041004
```
10051005

10061006
Example: BF16 to FP32
10071007
```mlir
1008-
%result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
1008+
%result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
10091009

10101010
Refer to
1011-
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1012-
detailed description of 2-way outer products.
1011+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1012+
description of 2-way outer products.
10131013

10141014
| Spec | Features |
10151015
| ---- | -------- |
@@ -1019,19 +1019,19 @@ def FMopsWide2WayOp
10191019
}];
10201020
}
10211021

1022-
def SMopaWide2WayOp
1023-
: OuterProductWide2Way<"smopa_wide_2way",
1022+
def SMopa2WayOp
1023+
: OuterProduct2Way<"smopa_2way",
10241024
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
10251025
[nxnxv4i32]> {
10261026
let summary = "Signed integer sum of 2 outer products and accumulate";
10271027
let description = [{
10281028
Example:
10291029
```mlir
1030-
%result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1030+
%result = arm_sme.smopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
10311031

10321032
Refer to
1033-
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1034-
detailed description of 2-way outer products.
1033+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1034+
description of 2-way outer products.
10351035

10361036
| Spec | Features |
10371037
| ---- | -------- |
@@ -1040,19 +1040,19 @@ def SMopaWide2WayOp
10401040
}];
10411041
}
10421042

1043-
def SMopsWide2WayOp
1044-
: OuterProductWide2Way<"smops_wide_2way",
1043+
def SMops2WayOp
1044+
: OuterProduct2Way<"smops_2way",
10451045
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
10461046
[nxnxv4i32]> {
10471047
let summary = "Signed integer sum of 2 outer products and subtract";
10481048
let description = [{
10491049
Example:
10501050
```mlir
1051-
%result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1051+
%result = arm_sme.smops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
10521052

10531053
Refer to
1054-
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1055-
detailed description of 2-way outer products.
1054+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1055+
description of 2-way outer products.
10561056

10571057
| Spec | Features |
10581058
| ---- | -------- |
@@ -1061,19 +1061,19 @@ def SMopsWide2WayOp
10611061
}];
10621062
}
10631063

1064-
def UMopaWide2WayOp
1065-
: OuterProductWide2Way<"umopa_wide_2way",
1064+
def UMopa2WayOp
1065+
: OuterProduct2Way<"umopa_2way",
10661066
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
10671067
[nxnxv4i32]> {
10681068
let summary = "Unsigned integer sum of 2 outer products and accumulate";
10691069
let description = [{
10701070
Example:
10711071
```mlir
1072-
%result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1072+
%result = arm_sme.umopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
10731073

10741074
Refer to
1075-
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1076-
detailed description of 2-way outer products.
1075+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1076+
description of 2-way outer products.
10771077

10781078
| Spec | Features |
10791079
| ---- | -------- |
@@ -1082,19 +1082,19 @@ def UMopaWide2WayOp
10821082
}];
10831083
}
10841084

1085-
def UMopsWide2WayOp
1086-
: OuterProductWide2Way<"umops_wide_2way",
1085+
def UMops2WayOp
1086+
: OuterProduct2Way<"umops_2way",
10871087
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
10881088
[nxnxv4i32]> {
10891089
let summary = "Unsigned integer sum of 2 outer products and subtract";
10901090
let description = [{
10911091
Example:
10921092
```mlir
1093-
%result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1093+
%result = arm_sme.umops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
10941094

10951095
Refer to
1096-
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1097-
detailed description of 2-way outer products.
1096+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1097+
description of 2-way outer products.
10981098

10991099
| Spec | Features |
11001100
| ---- | -------- |

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();
3434

35-
/// Pass that folds 'arm_sme.outerproduct' ops into widening variants.
36-
std::unique_ptr<Pass> createOuterProductWideningPass();
35+
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
36+
/// variants.
37+
std::unique_ptr<Pass> createOuterProductFusionPass();
3738

3839
//===----------------------------------------------------------------------===//
3940
// Registration

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ def TileAllocation
122122
let dependentDialects = ["func::FuncDialect"];
123123
}
124124

125-
def OuterProductWidening
126-
: Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
127-
let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
125+
def OuterProductFusion
126+
: Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
127+
let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
128128
let description = [{
129-
This pass folds 'arm_sme.outerproduct' operations that are chained via the
129+
This pass fuses 'arm_sme.outerproduct' operations that are chained via the
130130
accumulator into 2-way or 4-way ArmSME outer product operations.
131131

132132
For example:
@@ -145,14 +145,14 @@ def OuterProductWidening
145145
```mlir
146146
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
147147
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
148-
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
148+
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
149149
```
150150

151-
For further information on the widening ops see:
152-
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop
153-
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_wide_4way-arm_smesmopa_wide_4wayop
151+
For further information on the 2-way or 4-way widening ops see:
152+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
153+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
154154
}];
155-
let constructor = "mlir::arm_sme::createOuterProductWideningPass()";
155+
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
156156
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect", "LLVM::LLVMDialect"];
157157
}
158158

mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class LLVMTypeConverter;
1616
class RewritePatternSet;
1717

1818
namespace arm_sme {
19-
void populateOuterProductWideningPatterns(RewritePatternSet &patterns);
19+
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
2020
} // namespace arm_sme
2121

2222
} // namespace mlir

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -776,16 +776,16 @@ struct OuterProductOpConversion
776776
}
777777
};
778778

779-
/// Lower 2-way and 4-way outer products to intrinsics.
780-
template <class OuterProductWideOp, class OuterProductWideIntrOp>
781-
struct OuterProductWideOpConversion
782-
: public ConvertArmSMEOpToLLVMPattern<OuterProductWideOp> {
779+
/// Lower 2-way and 4-way widening outer products to intrinsics.
780+
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
781+
struct OuterProductWideningOpConversion
782+
: public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
783783
using ConvertArmSMEOpToLLVMPattern<
784-
OuterProductWideOp>::ConvertArmSMEOpToLLVMPattern;
784+
OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
785785

786786
LogicalResult
787-
matchAndRewrite(OuterProductWideOp op,
788-
typename OuterProductWideOp::Adaptor adaptor,
787+
matchAndRewrite(OuterProductWideningOp op,
788+
typename OuterProductWideningOp::Adaptor adaptor,
789789
ConversionPatternRewriter &rewriter) const override {
790790
auto tileId = getTileIdOrError(op);
791791
if (!tileId)
@@ -807,9 +807,9 @@ struct OuterProductWideOpConversion
807807
rhsMask = allActiveMask;
808808
}
809809

810-
rewriter.create<OuterProductWideIntrOp>(op.getLoc(), tileId, lhsMask,
811-
rhsMask, adaptor.getLhs(),
812-
adaptor.getRhs());
810+
rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
811+
rhsMask, adaptor.getLhs(),
812+
adaptor.getRhs());
813813

814814
// The outerproduct intrinsics have no result, replace
815815
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -927,18 +927,18 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
927927
LoadTileSliceConversion, MoveTileSliceToVectorConversion,
928928
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
929929
StreamingVLOpConversion, OuterProductOpConversion,
930-
OuterProductWideOpConversion<arm_sme::FMopaWide2WayOp,
931-
arm_sme::aarch64_sme_mopa_wide>,
932-
OuterProductWideOpConversion<arm_sme::FMopsWide2WayOp,
933-
arm_sme::aarch64_sme_mops_wide>,
934-
OuterProductWideOpConversion<arm_sme::SMopaWide2WayOp,
935-
arm_sme::aarch64_sme_smopa_za32>,
936-
OuterProductWideOpConversion<arm_sme::SMopsWide2WayOp,
937-
arm_sme::aarch64_sme_smops_za32>,
938-
OuterProductWideOpConversion<arm_sme::UMopaWide2WayOp,
939-
arm_sme::aarch64_sme_umopa_za32>,
940-
OuterProductWideOpConversion<arm_sme::UMopsWide2WayOp,
941-
arm_sme::aarch64_sme_umops_za32>,
930+
OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
931+
arm_sme::aarch64_sme_mopa_wide>,
932+
OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
933+
arm_sme::aarch64_sme_mops_wide>,
934+
OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
935+
arm_sme::aarch64_sme_smopa_za32>,
936+
OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
937+
arm_sme::aarch64_sme_smops_za32>,
938+
OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
939+
arm_sme::aarch64_sme_umopa_za32>,
940+
OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
941+
arm_sme::aarch64_sme_umops_za32>,
942942
ZeroOpConversion, GetTileConversion>(patterns, converter);
943943
}
944944

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
add_mlir_dialect_library(MLIRArmSMETransforms
22
EnableArmStreaming.cpp
3-
OuterProductWidening.cpp
3+
OuterProductFusion.cpp
44
TileAllocation.cpp
55

66
ADDITIONAL_HEADER_DIRS

0 commit comments

Comments
 (0)