Skip to content

Commit fff86c6

Browse files
authored
[mlir][ArmSME] Support 4-way widening outer products (#79288)
This patch introduces support for 4-way widening outer products. This enables the fusion of 4 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations. Changes: - Adds the following operations: - smopa_4way, smops_4way - umopa_4way, umops_4way - sumopa_4way, sumops_4way - sumopa_4way, sumops_4way - Implements conversions for the above ops to intrinsics in ArmSMEToLLVM. - Extends 'arm-sme-outer-product' pass. For a detailed description of these operations see the 'arm_sme.smopa_4way' description.
1 parent 7d508eb commit fff86c6

File tree

8 files changed

+1897
-46
lines changed

8 files changed

+1897
-46
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,331 @@ def UMops2WayOp
11031103
}];
11041104
}
11051105

1106+
class OuterProduct4Way<string mnemonic,
1107+
list<Type> allowedInputVectorTypes,
1108+
list<Type> allowedResultVectorTypes>
1109+
: OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
1110+
allowedResultVectorTypes, /*numOuterProducts=*/4>;
1111+
1112+
def SMopa4WayOp
1113+
: OuterProduct4Way<"smopa_4way",
1114+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1115+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1116+
[nxnxv4i32, nxnxv2i64]> {
1117+
let summary = "Signed integer sum of 4 outer products and accumulate";
1118+
let description = [{
1119+
This operation represents a sum of 4 widened outer products. It takes 2 1-D
1120+
scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
1121+
1122+
For example (i8 to i32):
1123+
1124+
```mlir
1125+
%result = arm_sme.smopa_4way $lhs, $rhs :
1126+
vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1127+
```
1128+
1129+
The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
1130+
4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
1131+
elements in a vector of SVL bits. To illustrate, below is a breakdown of
1132+
this operation for i8 to i32, SVL=128 (i.e., vscale=1):
1133+
1134+
```
1135+
LHS
1136+
[A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A15 A14 A15]
1137+
1138+
RHS
1139+
[B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
1140+
1141+
----------------------------------------------------------------------------
1142+
1143+
implicit layout
1144+
1145+
[A0 A1 A2 A3] | [B0 B4 B8 B12]
1146+
[A4 A5 A6 A7] | [B1 B5 B9 B13]
1147+
[A8 A9 A10 A11] | [B2 B6 B10 B14]
1148+
[A12 A13 A14 A15] | [B3 B7 B11 B15]
1149+
1150+
----------------------------------------------------------------------------
1151+
1152+
4 outer products
1153+
1154+
Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
1155+
------------- | -------------
1156+
|
1157+
[B0 B4 B8 B12] | [B1 B5 B9 B13]
1158+
|
1159+
[A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
1160+
A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
1161+
A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
1162+
A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
1163+
|
1164+
Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
1165+
------------- | -------------
1166+
|
1167+
[B2, B6, B10, B14] | [B3 B7 B11 B15]
1168+
|
1169+
[A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
1170+
A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
1171+
A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
1172+
A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
1173+
|
1174+
1175+
----------------------------------------------------------------------------
1176+
1177+
sum of 4 outer products
1178+
1179+
Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
1180+
1181+
[ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
1182+
[ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
1183+
[ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
1184+
[A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
1185+
1186+
----------------------------------------------------------------------------
1187+
```
1188+
1189+
This operation enables the folding of 4 outer products chained via the
1190+
accumulator into a single outer product.
1191+
1192+
For example:
1193+
1194+
```mlir
1195+
%a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
1196+
%b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
1197+
1198+
%a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
1199+
%b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
1200+
1201+
%a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
1202+
%b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
1203+
1204+
%a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
1205+
%b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
1206+
1207+
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
1208+
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
1209+
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
1210+
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
1211+
```
1212+
1213+
The 4 outer products in the example above can be fused into a single outer
1214+
product as follows:
1215+
1216+
```mlir
1217+
%lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1218+
%lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1219+
%lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
1220+
1221+
%rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1222+
%rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1223+
%rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
1224+
1225+
%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1226+
```
1227+
1228+
This is implemented in the `-arm-sme-outer-product-fusion` pass.
1229+
1230+
Example: I8 to I32
1231+
```mlir
1232+
%result = arm_sme.smopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1233+
```
1234+
1235+
Example: I16 to I64
1236+
```mlir
1237+
%result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1238+
```
1239+
1240+
| Spec | Features |
1241+
| ---- | -------- |
1242+
| [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1243+
}];
1244+
}
1245+
1246+
def SMops4WayOp
1247+
: OuterProduct4Way<"smops_4way",
1248+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1249+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1250+
[nxnxv4i32, nxnxv2i64]> {
1251+
let summary = "Signed integer sum of 4 outer products and subtract";
1252+
let description = [{
1253+
Equivalent to `smopa_4way` but outer products are subtracted from
1254+
destination `result`.
1255+
1256+
Example: I8 to I32
1257+
```mlir
1258+
%result = arm_sme.smops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1259+
```
1260+
1261+
Example: I16 to I64
1262+
```mlir
1263+
%result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1264+
```
1265+
1266+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1267+
detailed description of 4-way outer products.
1268+
1269+
| Spec | Features |
1270+
| ---- | -------- |
1271+
| [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1272+
}];
1273+
}
1274+
1275+
def UMopa4WayOp
1276+
: OuterProduct4Way<"umopa_4way",
1277+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1278+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1279+
[nxnxv4i32, nxnxv2i64]> {
1280+
let summary = "Unsigned integer sum of 4 outer products and accumulate";
1281+
let description = [{
1282+
Example: I8 to I32
1283+
```mlir
1284+
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1285+
```
1286+
1287+
Example: I16 to I64
1288+
```mlir
1289+
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1290+
```
1291+
1292+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1293+
detailed description of 4-way outer products.
1294+
1295+
| Spec | Features |
1296+
| ---- | -------- |
1297+
| [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1298+
}];
1299+
}
1300+
1301+
def UMops4WayOp
1302+
: OuterProduct4Way<"umops_4way",
1303+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1304+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1305+
[nxnxv4i32, nxnxv2i64]> {
1306+
let summary = "Unsigned integer sum of 4 outer products and subtract";
1307+
let description = [{
1308+
Example: I8 to I32
1309+
```mlir
1310+
%result = arm_sme.umops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1311+
```
1312+
1313+
Example: I16 to I64
1314+
```mlir
1315+
%result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1316+
```
1317+
1318+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1319+
detailed description of 4-way outer products.
1320+
1321+
| Spec | Features |
1322+
| ---- | -------- |
1323+
| [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1324+
}];
1325+
}
1326+
1327+
def SuMopa4WayOp
1328+
: OuterProduct4Way<"sumopa_4way",
1329+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1330+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1331+
[nxnxv4i32, nxnxv2i64]> {
1332+
let summary = "Signed by unsigned integer sum of 4 outer products and accumulate";
1333+
let description = [{
1334+
Example: I8 to I32
1335+
```mlir
1336+
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1337+
```
1338+
1339+
Example: I16 to I64
1340+
```mlir
1341+
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1342+
```
1343+
1344+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1345+
detailed description of 4-way outer products.
1346+
1347+
| Spec | Features |
1348+
| ---- | -------- |
1349+
| [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1350+
}];
1351+
}
1352+
1353+
def SuMops4WayOp
1354+
: OuterProduct4Way<"sumops_4way",
1355+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1356+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1357+
[nxnxv4i32, nxnxv2i64]> {
1358+
let summary = "Signed by unsigned integer sum of 4 outer products and subtract";
1359+
let description = [{
1360+
Example: I8 to I32
1361+
```mlir
1362+
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1363+
```
1364+
1365+
Example: I16 to I64
1366+
```mlir
1367+
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1368+
```
1369+
1370+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1371+
detailed description of 4-way outer products.
1372+
1373+
| Spec | Features |
1374+
| ---- | -------- |
1375+
| [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1376+
}];
1377+
}
1378+
1379+
def UsMopa4WayOp
1380+
: OuterProduct4Way<"usmopa_4way",
1381+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1382+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1383+
[nxnxv4i32, nxnxv2i64]> {
1384+
let summary = "Unsigned by signed integer sum of 4 outer products and accumulate";
1385+
let description = [{
1386+
Example: I8 to I32
1387+
```mlir
1388+
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1389+
```
1390+
1391+
Example: I16 to I64
1392+
```mlir
1393+
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1394+
```
1395+
1396+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1397+
detailed description of 4-way outer products.
1398+
1399+
| Spec | Features |
1400+
| ---- | -------- |
1401+
| [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1402+
}];
1403+
}
1404+
1405+
def UsMops4WayOp
1406+
: OuterProduct4Way<"usmops_4way",
1407+
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
1408+
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1409+
[nxnxv4i32, nxnxv2i64]> {
1410+
let summary = "Unsigned by signed integer sum of 4 outer products and subtract";
1411+
let description = [{
1412+
Example: I8 to I32
1413+
```mlir
1414+
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
1415+
```
1416+
1417+
Example: I16 to I64
1418+
```mlir
1419+
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
1420+
```
1421+
1422+
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
1423+
detailed description of 4-way outer products.
1424+
1425+
| Spec | Features |
1426+
| ---- | -------- |
1427+
| [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
1428+
}];
1429+
}
1430+
11061431
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
11071432
{
11081433
let summary = "Query the streaming vector length";

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,22 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
939939
arm_sme::aarch64_sme_umopa_za32>,
940940
OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
941941
arm_sme::aarch64_sme_umops_za32>,
942+
OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
943+
arm_sme::aarch64_sme_smopa_wide>,
944+
OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
945+
arm_sme::aarch64_sme_smops_wide>,
946+
OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
947+
arm_sme::aarch64_sme_umopa_wide>,
948+
OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
949+
arm_sme::aarch64_sme_umops_wide>,
950+
OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
951+
arm_sme::aarch64_sme_sumopa_wide>,
952+
OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
953+
arm_sme::aarch64_sme_sumops_wide>,
954+
OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
955+
arm_sme::aarch64_sme_usmopa_wide>,
956+
OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
957+
arm_sme::aarch64_sme_usmops_wide>,
942958
ZeroOpConversion, GetTileConversion>(patterns, converter);
943959
}
944960

0 commit comments

Comments
 (0)