Skip to content

[mlir][spirv] Fix a crash of typeConverter with non supported type #79955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 64 additions & 55 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
auto dstType = typeConverter.convertType(loadOp.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
isVolatile, isNonTemporal);
Expand Down Expand Up @@ -357,22 +357,23 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
// To use GEP we need to add a first 0 index to go through the pointer.
auto indices = llvm::to_vector<4>(adaptor.getIndices());
Type indexType = op.getIndices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
op, dstType,
typeConverter.convertType(
cast<spirv::PointerType>(op.getBasePtr().getType())
.getPointeeType()),
adaptor.getBasePtr(), indices);

auto elementType = typeConverter.convertType(
cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
if (!elementType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
adaptor.getBasePtr(), indices);
return success();
}
};
Expand All @@ -386,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getPointer().getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
op.getVariable());
return success();
Expand All @@ -404,7 +405,7 @@ class BitFieldInsertPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();

// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Expand Down Expand Up @@ -451,7 +452,7 @@ class ConstantScalarAndVectorPattern

auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(constOp, "type conversion failed");

// SPIR-V constant can be a signed/unsigned integer, which has to be
// casted to signless integer when converting to LLVM dialect. Removing the
Expand Down Expand Up @@ -492,7 +493,7 @@ class BitFieldSExtractPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();

// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Expand Down Expand Up @@ -545,7 +546,7 @@ class BitFieldUExtractPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();

// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Expand Down Expand Up @@ -621,7 +622,7 @@ class CompositeExtractPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

Type containerType = op.getComposite().getType();
if (isa<VectorType>(containerType)) {
Expand Down Expand Up @@ -653,7 +654,7 @@ class CompositeInsertPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

Type containerType = op.getComposite().getType();
if (isa<VectorType>(containerType)) {
Expand All @@ -680,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(operation.getType());
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.template replaceOpWithNewOp<LLVMOp>(
operation, dstType, adaptor.getOperands(), operation->getAttrs());
op, dstType, adaptor.getOperands(), op->getAttrs());
return success();
}
};
Expand Down Expand Up @@ -790,7 +791,7 @@ class GlobalVariablePattern
auto srcType = cast<spirv::PointerType>(op.getType());
auto dstType = typeConverter.convertType(srcType.getPointeeType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

// Limit conversion to the current invocation only or `StorageBuffer`
// required by SPIR-V runner.
Expand Down Expand Up @@ -843,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type fromType = operation.getOperand().getType();
Type toType = operation.getType();
Type fromType = op.getOperand().getType();
Type toType = op.getType();

auto dstType = this->typeConverter.convertType(toType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

if (getBitWidth(fromType) < getBitWidth(toType)) {
rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
adaptor.getOperands());
return success();
}
if (getBitWidth(fromType) > getBitWidth(toType)) {
rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
adaptor.getOperands());
return success();
}
Expand All @@ -883,6 +884,8 @@ class FunctionCallPattern

// Function returns a single result.
auto dstType = typeConverter.convertType(callOp.getType(0));
if (!dstType)
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
return success();
Expand All @@ -896,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto dstType = this->typeConverter.convertType(operation.getType());
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
operation, dstType, predicate, operation.getOperand1(),
operation.getOperand2());
op, dstType, predicate, op.getOperand1(), op.getOperand2());
return success();
}
};
Expand All @@ -917,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto dstType = this->typeConverter.convertType(operation.getType());
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
operation, dstType, predicate, operation.getOperand1(),
operation.getOperand2());
op, dstType, predicate, op.getOperand1(), op.getOperand2());
return success();
}
};
Expand All @@ -942,7 +943,7 @@ class InverseSqrtPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Expand Down Expand Up @@ -1000,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
auto srcType = notOp.getType();
auto dstType = this->typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(notOp, "type conversion failed");

Location loc = notOp.getLoc();
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
Expand Down Expand Up @@ -1226,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto dstType = this->typeConverter.convertType(operation.getType());
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(op, "type conversion failed");

Type op1Type = operation.getOperand1().getType();
Type op2Type = operation.getOperand2().getType();
Type op1Type = op.getOperand1().getType();
Type op2Type = op.getOperand2().getType();

if (op1Type == op2Type) {
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
adaptor.getOperands());
return success();
}
Expand All @@ -1250,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
if (!dstTypeWidth || !op2TypeWidth)
return failure();

Location loc = operation.getLoc();
Location loc = op.getLoc();
Value extended;
if (op2TypeWidth < dstTypeWidth) {
if (isUnsignedIntegerOrVector(op2Type)) {
Expand All @@ -1268,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {

Value result = rewriter.template create<LLVMOp>(
loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(operation, result);
rewriter.replaceOp(op, result);
return success();
}
};
Expand All @@ -1282,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(tanOp.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");

Location loc = tanOp.getLoc();
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
Expand All @@ -1308,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
auto srcType = tanhOp.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");

Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
Expand Down Expand Up @@ -1342,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {

auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(varOp, "type conversion failed");

Location loc = varOp.getLoc();
Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
varOp, dstType, typeConverter.convertType(pointerTo), size);
auto elementType = typeConverter.convertType(pointerTo);
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
size);
return success();
}
Value allocated = rewriter.create<LLVM::AllocaOp>(
loc, dstType, typeConverter.convertType(pointerTo), size);
auto elementType = typeConverter.convertType(pointerTo);
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Value allocated =
rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
Expand All @@ -1373,7 +1380,7 @@ class BitcastConversionPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(bitcastOp.getType());
if (!dstType)
return failure();
return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");

// LLVM's opaque pointers do not require bitcasts.
if (isa<LLVM::LLVMPointerType>(dstType)) {
Expand Down Expand Up @@ -1499,6 +1506,8 @@ class VectorShufflePattern
}

auto dstType = typeConverter.convertType(op.getType());
if (!dstType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
auto scalarType = cast<VectorType>(dstType).getElementType();
auto componentsArray = components.getValue();
auto *context = rewriter.getContext();
Expand Down