Skip to content

[mlir][ArmSME] Support decomposing constant splats into ArmSME tiles #88762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Apr 15, 2024

This adds a simple rewrite/legalization to decompose constant splats larger than a single ArmSME tile into multiple SME virtual tile sized splats. E.g. a constant splat to vector<[8]x[8]xi32> would decompose into four vector<[4]x[4]xi32> splats.

This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual tile sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.
@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This adds a simple rewrite/legalization to decompose constant splats larger than a single ArmSME tile into multiple SME virtual tile sized splats. E.g. a constant splat to vector&lt;[8]x[8]xi32&gt; would decompose into four vector&lt;[4]x[4]xi32&gt; splats.


Full diff: https://github.com/llvm/llvm-project/pull/88762.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+31-1)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+11)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c62c0d600..b595c6dd8a6848 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+    : public OneToNOpConversionPattern<arith::ConstantOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!vectorType || !denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(constantOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+    auto tileSplat = rewriter.create<arith::ConstantOp>(
+        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+                       adaptor.getResultMapping());
+
+    return success();
+  }
+};
+
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+    patterns.add<LegalizeArithConstantOpsByDecomposition,
+                 LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697548c197..f43ef1cce787c5 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
   %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
   return %cast : vector<[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+  // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@MacDue MacDue merged commit dadcaf8 into llvm:main Apr 16, 2024
@MacDue MacDue deleted the sme_splats branch April 16, 2024 11:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants