Skip to content

[mlir][spirv] Clean up map-memref-storage-class pass #79937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 81 additions & 81 deletions mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -54,7 +56,8 @@ using namespace mlir;
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
MAP_FN(spirv::StorageClass::Output, 10)
MAP_FN(spirv::StorageClass::Output, 10) \
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)

std::optional<spirv::StorageClass>
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
Expand Down Expand Up @@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
});

addConversion([this](FunctionType type) {
SmallVector<Type> inputs, results;
inputs.reserve(type.getNumInputs());
results.reserve(type.getNumResults());
for (Type input : type.getInputs())
inputs.push_back(convertType(input));
for (Type result : type.getResults())
results.push_back(convertType(result));
auto inputs = llvm::map_to_vector(
type.getInputs(), [this](Type ty) { return convertType(ty); });
auto results = llvm::map_to_vector(
type.getResults(), [this](Type ty) { return convertType(ty); });
return FunctionType::get(type.getContext(), inputs, results);
});
}
Expand Down Expand Up @@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
namespace {
/// Converts any op that has operands/results/attributes with numeric MemRef
/// memory spaces.
struct MapMemRefStoragePattern final : public ConversionPattern {
struct MapMemRefStoragePattern final : ConversionPattern {
MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult MapMemRefStoragePattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<NamedAttribute, 4> newAttrs;
newAttrs.reserve(op->getAttrs().size());
for (auto attr : op->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
} else {
newAttrs.push_back(attr);
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<NamedAttribute> newAttrs;
newAttrs.reserve(op->getAttrs().size());
for (NamedAttribute attr : op->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
if (!newAttr) {
return rewriter.notifyMatchFailure(
op, "type attribute conversion failed");
}
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
} else {
newAttrs.push_back(attr);
}
}
}

llvm::SmallVector<Type, 4> newResults;
(void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);

OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
newResults, newAttrs, op->getSuccessors());
llvm::SmallVector<Type, 4> newResults;
if (failed(
getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
return rewriter.notifyMatchFailure(op, "result type conversion failed");

OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
newResults, newAttrs, op->getSuccessors());

for (Region &region : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
if (failed(getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result))) {
return rewriter.notifyMatchFailure(
op, "signature argument type conversion failed");
}
rewriter.applySignatureConversion(newRegion, result);
}

for (Region &region : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(newRegion, result);
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}

Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};
} // namespace

void spirv::populateMemorySpaceToStorageClassPatterns(
spirv::MemorySpaceToStorageClassConverter &typeConverter,
Expand All @@ -308,58 +313,53 @@ namespace {
class MapMemRefStorageClassPass final
: public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
public:
explicit MapMemRefStorageClassPass() {
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
MapMemRefStorageClassPass() = default;

explicit MapMemRefStorageClassPass(
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
: memorySpaceMap(memorySpaceMap) {}

LogicalResult initializeOptions(StringRef options) override;

void runOnOperation() override;

private:
spirv::MemorySpaceToStorageClassMap memorySpaceMap;
};
} // namespace
LogicalResult initializeOptions(StringRef options) override {
if (failed(Pass::initializeOptions(options)))
return failure();

LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
if (failed(Pass::initializeOptions(options)))
return failure();
if (clientAPI == "opencl")
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
else if (clientAPI != "vulkan")
return failure();

if (clientAPI == "opencl") {
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
return success();
}

if (clientAPI != "vulkan" && clientAPI != "opencl")
return failure();
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();

if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
spirv::TargetEnv targetEnv(attr);
if (targetEnv.allows(spirv::Capability::Kernel)) {
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
} else if (targetEnv.allows(spirv::Capability::Shader)) {
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
}

return success();
}
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);

void MapMemRefStorageClassPass::runOnOperation() {
MLIRContext *context = &getContext();
Operation *op = getOperation();
RewritePatternSet patterns(context);
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);

if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
spirv::TargetEnv targetEnv(attr);
if (targetEnv.allows(spirv::Capability::Kernel)) {
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
} else if (targetEnv.allows(spirv::Capability::Shader)) {
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
if (failed(applyFullConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}

auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);

RewritePatternSet patterns(context);
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);

if (failed(applyFullConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
private:
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
spirv::mapMemorySpaceToVulkanStorageClass;
};
} // namespace

std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
return std::make_unique<MapMemRefStorageClassPass>();
Expand Down