-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Clean up map-memref-storage-class pass #79937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling. * Add switch case for physical storage bufer * Handle type conversion failures * Inline methods to reduce scrolling * Clean up code
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Jakub Kuderski (kuhar) ChangesClean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling.
Full diff: https://github.com/llvm/llvm-project/pull/79937.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index c6ef5be2494ad..cb969e0b5d7f3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -18,10 +18,12 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -54,7 +56,8 @@ using namespace mlir;
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
- MAP_FN(spirv::StorageClass::Output, 10)
+ MAP_FN(spirv::StorageClass::Output, 10) \
+ MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
std::optional<spirv::StorageClass>
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
});
addConversion([this](FunctionType type) {
- SmallVector<Type> inputs, results;
- inputs.reserve(type.getNumInputs());
- results.reserve(type.getNumResults());
- for (Type input : type.getInputs())
- inputs.push_back(convertType(input));
- for (Type result : type.getResults())
- results.push_back(convertType(result));
+ auto inputs = llvm::to_vector(llvm::map_range(
+ type.getInputs(), [this](Type ty) { return convertType(ty); }));
+ auto results = llvm::to_vector(llvm::map_range(
+ type.getResults(), [this](Type ty) { return convertType(ty); }));
return FunctionType::get(type.getContext(), inputs, results);
});
}
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
namespace {
/// Converts any op that has operands/results/attributes with numeric MemRef
/// memory spaces.
-struct MapMemRefStoragePattern final : public ConversionPattern {
+struct MapMemRefStoragePattern final : ConversionPattern {
MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult MapMemRefStoragePattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm::SmallVector<NamedAttribute, 4> newAttrs;
- newAttrs.reserve(op->getAttrs().size());
- for (auto attr : op->getAttrs()) {
- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
- } else {
- newAttrs.push_back(attr);
+ ConversionPatternRewriter &rewriter) const override {
+ llvm::SmallVector<NamedAttribute> newAttrs;
+ newAttrs.reserve(op->getAttrs().size());
+ for (NamedAttribute attr : op->getAttrs()) {
+ if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
+ Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
+ if (!newAttr) {
+ return rewriter.notifyMatchFailure(
+ op, "type attribute conversion failed");
+ }
+ newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+ } else {
+ newAttrs.push_back(attr);
+ }
}
- }
- llvm::SmallVector<Type, 4> newResults;
- (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, newAttrs, op->getSuccessors());
+ llvm::SmallVector<Type, 4> newResults;
+ if (failed(
+ getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
+ return rewriter.notifyMatchFailure(op, "result type conversion failed");
+
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, newAttrs, op->getSuccessors());
+
+ for (Region ®ion : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+ if (failed(getTypeConverter()->convertSignatureArgs(
+ newRegion->getArgumentTypes(), result))) {
+ return rewriter.notifyMatchFailure(
+ op, "signature argument type conversion failed");
+ }
+ rewriter.applySignatureConversion(newRegion, result);
+ }
- for (Region ®ion : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- TypeConverter::SignatureConversion result(newRegion->getNumArguments());
- (void)getTypeConverter()->convertSignatureArgs(
- newRegion->getArgumentTypes(), result);
- rewriter.applySignatureConversion(newRegion, result);
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
}
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
-}
+};
+} // namespace
void spirv::populateMemorySpaceToStorageClassPatterns(
spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -315,51 +320,46 @@ class MapMemRefStorageClassPass final
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
: memorySpaceMap(memorySpaceMap) {}
- LogicalResult initializeOptions(StringRef options) override;
-
- void runOnOperation() override;
-
-private:
- spirv::MemorySpaceToStorageClassMap memorySpaceMap;
-};
-} // namespace
+ LogicalResult initializeOptions(StringRef options) override {
+ if (failed(Pass::initializeOptions(options)))
+ return failure();
-LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
- if (failed(Pass::initializeOptions(options)))
- return failure();
+ if (clientAPI == "opencl")
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ else if (clientAPI != "vulkan")
+ return failure();
- if (clientAPI == "opencl") {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ return success();
}
- if (clientAPI != "vulkan" && clientAPI != "opencl")
- return failure();
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ Operation *op = getOperation();
+
+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+ spirv::TargetEnv targetEnv(attr);
+ if (targetEnv.allows(spirv::Capability::Kernel)) {
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ } else if (targetEnv.allows(spirv::Capability::Shader)) {
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+ }
+ }
- return success();
-}
+ std::unique_ptr<ConversionTarget> target =
+ spirv::getMemorySpaceToStorageClassTarget(*context);
+ spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-void MapMemRefStorageClassPass::runOnOperation() {
- MLIRContext *context = &getContext();
- Operation *op = getOperation();
+ RewritePatternSet patterns(context);
+ spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
- if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
- spirv::TargetEnv targetEnv(attr);
- if (targetEnv.allows(spirv::Capability::Kernel)) {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
- } else if (targetEnv.allows(spirv::Capability::Shader)) {
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
- }
+ if (failed(applyFullConversion(op, *target, std::move(patterns))))
+ return signalPassFailure();
}
- auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
- spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
- if (failed(applyFullConversion(op, *target, std::move(patterns))))
- return signalPassFailure();
-}
+private:
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap;
+};
+} // namespace
std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
return std::make_unique<MapMemRefStorageClassPass>();
|
The Windows CI is having some infra issues, testing on linux passed. |
Maybe we should acutally use |
Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling.