@@ -240,9 +240,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
240
240
LogicalResult
241
241
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
242
242
ConversionPatternRewriter &rewriter) const override {
243
- auto tileOp = dyn_cast <arm_sme::ArmSMETileOpInterface>(op);
243
+ auto tileOp = cast <arm_sme::ArmSMETileOpInterface>(op);
244
244
// Tile has a real (hardware) tile. No spills/reloads required.
245
- if (!tileOp || !tileOp .isInMemoryTile ())
245
+ if (!tileOp.isInMemoryTile ())
246
246
return failure ();
247
247
248
248
// Step 1. Create an alloca for the tile at the top of the function (if one
@@ -364,6 +364,31 @@ struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
364
364
}
365
365
};
366
366
367
+ // / Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
368
+ template <typename ... Pattern>
369
+ static void
370
+ addArmSMEConversionPatterns (RewritePatternSet &patterns,
371
+ LLVMTypeConverter const &typeConverter) {
372
+ (
373
+ [&] {
374
+ // Register spills/fills for ops that implement the
375
+ // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
376
+ // `RequiresSpillsAndFills::Yes`.
377
+ if constexpr (Pattern::requiresSpillsAndFillsConversion () &&
378
+ std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
379
+ typename Pattern::ArmSMEOp>,
380
+ typename Pattern::ArmSMEOp>) {
381
+ // Add spill/fill conversions with a very high benefit to ensure
382
+ // they are lowered first.
383
+ patterns.add <ConvertArmSMESpillsAndFillsToLLVM>(
384
+ Pattern::ArmSMEOp::getOperationName (), typeConverter,
385
+ /* benefit=*/ 1337 );
386
+ }
387
+ patterns.add <Pattern>(typeConverter);
388
+ }(),
389
+ ...);
390
+ }
391
+
367
392
struct GetTileConversion
368
393
: public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
369
394
RequiresSpillsAndFills::No> {
@@ -818,25 +843,6 @@ struct ConvertArmSMEToLLVMPass
818
843
}
819
844
};
820
845
821
- // / Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
822
- template <typename ... Pattern>
823
- static void
824
- addArmSMEConversionPatterns (RewritePatternSet &patterns,
825
- LLVMTypeConverter const &typeConverter) {
826
- (
827
- [&] {
828
- if (Pattern::requiresSpillsAndFillsConversion ()) {
829
- // Add spill/fill conversions with a very high benefit to ensure they
830
- // are lowered first.
831
- patterns.add <ConvertArmSMESpillsAndFillsToLLVM>(
832
- Pattern::ArmSMEOp::getOperationName (), typeConverter,
833
- /* benefit=*/ 1337 );
834
- }
835
- patterns.add <Pattern>(typeConverter);
836
- }(),
837
- ...);
838
- }
839
-
840
846
} // namespace
841
847
842
848
void mlir::configureArmSMEToLLVMConversionLegality (ConversionTarget &target) {
0 commit comments