Skip to content

[mlir][ArmSME] Add initial SME vector legalization pass #79152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ std::unique_ptr<Pass> createTileAllocationPass();
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();

/// Pass that legalizes vectors so they can be lowered to ArmSME.
std::unique_ptr<Pass> createVectorLegalizationPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,27 @@ def OuterProductFusion
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
}

def VectorLegalization
: Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
let summary = "Legalize vectors for ArmSME";
let description = [{
This pass legalizes vector operations so that they can be lowered to ArmSME.
This includes decomposing operations that operate on vector types larger
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector<[8]x[8]xf32> matches the size of ZA, right? How about vectors that are larger than this? Or smaller?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That matches the size of ZA. There's infinite bigger vector types (this pass does not care how many tiles the resulting type uses), and a few smaller sizes. There may also be cases when you want sizes smaller than ZA (but more than a tile), i.e. a matmul with a rather small width/height but a large reduction dimension. This pass allows the tiling to be fairly flexible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - but are there any limitations? I know that there are - those are effectively enforced by isMultipleOfSMETileVectorType, right? It would be good add a few more comments here, e.g.:

  • "legal" in this context means something that matches SMEs virtual tiles as available in hardware
  • "input" type must be a multiple of these tiles (so definitely 2d and scalable in both dimensions)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

legal - means something that can be lowered to ArmSME, as mentioned this pass will be broader than just type decomposition, and will include rewrites to eliminate illegal types like vector<[8]x4xf32> (which also falls under legalizing vector operations).

(I think the description already mentions that the decomposition breaks down larger vector types into multiple SME-sized operations?)

I'll add the input size limitation 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

legal - means something that can be lowered to ArmSME

Perhaps you could try to defining that? I assume it means either:

  • 2d scalable vectors with sizes matching SME virtual tiles,
  • 1d scalable vectors
  • n-D vectors with at most the trailing dim scalable?

tile-sized operations, as well as rewrites needed to get operations into
forms compatible with SME lowerings.

Note: Decomposition is currently limited to vector types that are an exact
multiple of SME tiles. That is scalable in two dimensions, with both the
rows and columns divisible by the SVE vector length for the element type.
}];
let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
let dependentDialects = [
"func::FuncDialect",
"arm_sme::ArmSMEDialect",
"vector::VectorDialect",
"arith::ArithDialect"
];
}

#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
PatternRewriter &rewriter, Location loc, Value initTile,
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);

/// Returns true if `vType` is a multiple of an SME tile size. Returns false if
/// the `vType` exactly matches the size of an SME tile.
bool isMultipleOfSMETileVectorType(VectorType vType);

/// Creates a vector type for the SME tile of `elementType`.
VectorType getSMETileTypeForElement(Type elementType);

} // namespace mlir::arm_sme

#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/ArmSME/IR/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
return forOp;
}

bool isMultipleOfSMETileVectorType(VectorType vType) {
if (vType.getRank() != 2 || !vType.allDimsScalable())
return false;

auto elementType = vType.getElementType();
if (!isValidSMETileElementType(elementType))
return false;

unsigned minNumElts = getSMETileSliceMinNumElts(elementType);

int64_t vectorRows = vType.getDimSize(0);
int64_t vectorCols = vType.getDimSize(1);

return (vectorRows > minNumElts || vectorCols > minNumElts) &&
vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
}

VectorType getSMETileTypeForElement(Type elementType) {
unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
}

} // namespace mlir::arm_sme
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
OuterProductFusion.cpp
TileAllocation.cpp
VectorLegalization.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
Expand All @@ -10,10 +11,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMETransformsIncGen

LINK_LIBS PUBLIC
MLIRPass
MLIRArmSMEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRSCFDialect
MLIRPass
MLIRSCFTransforms
MLIRFuncTransforms
)
Loading