Skip to content

[mlir][spirv] Clean up map-memref-storage-class pass #79937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Jan 30, 2024

Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling.

  • Add switch case for physical storage buffer
  • Handle type conversion failures
  • Inline methods to reduce scrolling
  • Other minor cleanups

Clean up the code before making more substantial changes. NFC modulo extra
error checking and physical storage buffer storage class handling.

* Add switch case for physical storage bufer
* Handle type conversion failures
* Inline methods to reduce scrolling
* Clean up code
@llvmbot
Copy link
Member

llvmbot commented Jan 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling.

  • Add switch case for physical storage buffer
  • Handle type conversion failures
  • Inline methods to reduce scrolling
  • Other minor cleanups

Full diff: https://github.com/llvm/llvm-project/pull/79937.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp (+78-78)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index c6ef5be2494ad..cb969e0b5d7f3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -18,10 +18,12 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
@@ -54,7 +56,8 @@ using namespace mlir;
   MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
   MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
   MAP_FN(spirv::StorageClass::Input, 9)                                        \
-  MAP_FN(spirv::StorageClass::Output, 10)
+  MAP_FN(spirv::StorageClass::Output, 10)                                      \
+  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
 
 std::optional<spirv::StorageClass>
 spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
   });
 
   addConversion([this](FunctionType type) {
-    SmallVector<Type> inputs, results;
-    inputs.reserve(type.getNumInputs());
-    results.reserve(type.getNumResults());
-    for (Type input : type.getInputs())
-      inputs.push_back(convertType(input));
-    for (Type result : type.getResults())
-      results.push_back(convertType(result));
+    auto inputs = llvm::to_vector(llvm::map_range(
+        type.getInputs(), [this](Type ty) { return convertType(ty); }));
+    auto results = llvm::to_vector(llvm::map_range(
+        type.getResults(), [this](Type ty) { return convertType(ty); }));
     return FunctionType::get(type.getContext(), inputs, results);
   });
 }
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
 namespace {
 /// Converts any op that has operands/results/attributes with numeric MemRef
 /// memory spaces.
-struct MapMemRefStoragePattern final : public ConversionPattern {
+struct MapMemRefStoragePattern final : ConversionPattern {
   MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
       : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult MapMemRefStoragePattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  llvm::SmallVector<NamedAttribute, 4> newAttrs;
-  newAttrs.reserve(op->getAttrs().size());
-  for (auto attr : op->getAttrs()) {
-    if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-      auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-      newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-    } else {
-      newAttrs.push_back(attr);
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<NamedAttribute> newAttrs;
+    newAttrs.reserve(op->getAttrs().size());
+    for (NamedAttribute attr : op->getAttrs()) {
+      if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
+        Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
+        if (!newAttr) {
+          return rewriter.notifyMatchFailure(
+              op, "type attribute conversion failed");
+        }
+        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+      } else {
+        newAttrs.push_back(attr);
+      }
     }
-  }
 
-  llvm::SmallVector<Type, 4> newResults;
-  (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
-  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                       newResults, newAttrs, op->getSuccessors());
+    llvm::SmallVector<Type, 4> newResults;
+    if (failed(
+            getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
+      return rewriter.notifyMatchFailure(op, "result type conversion failed");
+
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, newAttrs, op->getSuccessors());
+
+    for (Region &region : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      if (failed(getTypeConverter()->convertSignatureArgs(
+              newRegion->getArgumentTypes(), result))) {
+        return rewriter.notifyMatchFailure(
+            op, "signature argument type conversion failed");
+      }
+      rewriter.applySignatureConversion(newRegion, result);
+    }
 
-  for (Region &region : op->getRegions()) {
-    Region *newRegion = state.addRegion();
-    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-    TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-    (void)getTypeConverter()->convertSignatureArgs(
-        newRegion->getArgumentTypes(), result);
-    rewriter.applySignatureConversion(newRegion, result);
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
   }
-
-  Operation *newOp = rewriter.create(state);
-  rewriter.replaceOp(op, newOp->getResults());
-  return success();
-}
+};
+} // namespace
 
 void spirv::populateMemorySpaceToStorageClassPatterns(
     spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -315,51 +320,46 @@ class MapMemRefStorageClassPass final
       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
       : memorySpaceMap(memorySpaceMap) {}
 
-  LogicalResult initializeOptions(StringRef options) override;
-
-  void runOnOperation() override;
-
-private:
-  spirv::MemorySpaceToStorageClassMap memorySpaceMap;
-};
-} // namespace
+  LogicalResult initializeOptions(StringRef options) override {
+    if (failed(Pass::initializeOptions(options)))
+      return failure();
 
-LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
-  if (failed(Pass::initializeOptions(options)))
-    return failure();
+    if (clientAPI == "opencl")
+      memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+    else if (clientAPI != "vulkan")
+      return failure();
 
-  if (clientAPI == "opencl") {
-    memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+    return success();
   }
 
-  if (clientAPI != "vulkan" && clientAPI != "opencl")
-    return failure();
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    Operation *op = getOperation();
+
+    if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+      spirv::TargetEnv targetEnv(attr);
+      if (targetEnv.allows(spirv::Capability::Kernel)) {
+        memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+      } else if (targetEnv.allows(spirv::Capability::Shader)) {
+        memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+      }
+    }
 
-  return success();
-}
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
 
-void MapMemRefStorageClassPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  Operation *op = getOperation();
+    RewritePatternSet patterns(context);
+    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
 
-  if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
-    spirv::TargetEnv targetEnv(attr);
-    if (targetEnv.allows(spirv::Capability::Kernel)) {
-      memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
-    } else if (targetEnv.allows(spirv::Capability::Shader)) {
-      memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
-    }
+    if (failed(applyFullConversion(op, *target, std::move(patterns))))
+      return signalPassFailure();
   }
 
-  auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
-  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
-  RewritePatternSet patterns(context);
-  spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-  if (failed(applyFullConversion(op, *target, std::move(patterns))))
-    return signalPassFailure();
-}
+private:
+  spirv::MemorySpaceToStorageClassMap memorySpaceMap;
+};
+} // namespace
 
 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
   return std::make_unique<MapMemRefStorageClassPass>();

@kuhar kuhar changed the title [mlir][spirv] Clean up map memref-storage-class pass [mlir][spirv] Clean up map-memref-storage-class pass Jan 30, 2024
@kuhar
Copy link
Member Author

kuhar commented Jan 30, 2024

The Windows CI is having some infra issues, testing on linux passed.

@kuhar kuhar merged commit 7f6d445 into llvm:main Jan 30, 2024
@antiagainst
Copy link
Member

Maybe we should acutally use AttrTypeReplacer like iree-org/iree#16238.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants