Skip to content

[mlir][LLVM] Switch undef for poison for uninitialized values #125629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class ComplexStructBuilder : public StructBuilder {
/// Construct a helper for the given complex number value.
using StructBuilder::StructBuilder;
/// Build IR creating an `undef` value of the complex number type.
static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
Type type);
static ComplexStructBuilder poison(OpBuilder &builder, Location loc,
Type type);

// Build IR extracting the real value from the complex number struct.
Value real(OpBuilder &builder, Location loc);
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a `poison` value of the descriptor type.
static MemRefDescriptor poison(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
Expand Down Expand Up @@ -160,8 +160,8 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Construct a helper for the given descriptor value.
explicit UnrankedMemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc,
Type descriptorType);

/// Builds IR extracting the rank from the descriptor
Value rank(OpBuilder &builder, Location loc) const;
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class StructBuilder {
public:
/// Construct a helper for the given value.
explicit StructBuilder(Value v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a `poison` value of the descriptor type.
static StructBuilder poison(OpBuilder &builder, Location loc,
Type descriptorType);

/*implicit*/ operator Value() { return value; }

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ using namespace mlir::arith;
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;

ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::UndefOp>(loc, type);
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::PoisonOp>(loc, type);
return ComplexStructBuilder(val);
}

Expand Down Expand Up @@ -109,7 +109,8 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
// Pack real and imaginary part in a complex number struct.
auto loc = complexOp.getLoc();
auto structType = typeConverter->convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
auto complexStruct =
ComplexStructBuilder::poison(rewriter, loc, structType);
complexStruct.setReal(rewriter, loc, adaptor.getReal());
complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());

Expand Down Expand Up @@ -183,7 +184,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -214,7 +215,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -262,7 +263,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -302,7 +303,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to substract complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ struct UnrealizedConversionCastOpLowering
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
// necessary before returning it
struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -714,7 +714,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}

Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
return rewriter.notifyMatchFailure(op, "expected vector result");

Location loc = op->getLoc();
Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
Expand Down Expand Up @@ -771,7 +771,7 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}

Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ struct WmmaConstantOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), vecEl);
Expand All @@ -288,7 +288,7 @@ struct WmmaConstantOpToNVVMLowering
}
cst = vecCst;
}
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
matrixStruct =
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
Expand Down Expand Up @@ -355,7 +355,7 @@ struct WmmaElementwiseOpToNVVMLowering
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
}

/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
Type descriptorType) {

Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
return MemRefDescriptor(descriptor);
}

Expand Down Expand Up @@ -60,7 +60,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
auto convertedType = typeConverter.convertType(type);
assert(convertedType && "unexpected failure in memref type conversion");

auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
auto descr = MemRefDescriptor::poison(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
descr.setAlignedPtr(builder, loc, alignedMemory);
descr.setConstantOffset(builder, loc, offset);
Expand Down Expand Up @@ -224,7 +224,7 @@ Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
const LLVMTypeConverter &converter,
MemRefType type, ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
auto d = MemRefDescriptor::poison(builder, loc, llvmType);

d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
Expand Down Expand Up @@ -300,10 +300,10 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {}

/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
Expand Down Expand Up @@ -331,7 +331,7 @@ Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
UnrankedMemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType);

d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const {
auto structType = typeConverter->convertType(memRefType);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);

// Field 1: Allocated pointer, used for malloc/free.
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
Expand Down Expand Up @@ -319,7 +319,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (!descriptorType)
return failure();
auto updatedDesc =
UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
Value rank = desc.rank(builder, loc);
updatedDesc.setRank(builder, loc, rank);
updatedDesc.setMemRefDescPtr(builder, loc, memory);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,10 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(rank));
// undef = UndefOp
// poison = PoisonOp
UnrankedMemRefDescriptor memRefDesc =
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
// d1 = InsertValueOp undef, rank, 0
UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
// d1 = InsertValueOp poison, rank, 0
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, ptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
Expand Down Expand Up @@ -928,7 +928,7 @@ struct MemorySpaceCastOpLowering
Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

// Create and allocate storage for new memref descriptor.
auto result = UnrankedMemRefDescriptor::undef(
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
SmallVector<Value, 1> sizes;
Expand Down Expand Up @@ -1058,7 +1058,7 @@ struct MemRefReinterpretCastOpLowering

// Create descriptor.
Location loc = castOp.getLoc();
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
Expand Down Expand Up @@ -1128,7 +1128,7 @@ struct MemRefReshapeOpLowering
// Create descriptor.
Location loc = reshapeOp.getLoc();
auto desc =
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
Expand Down Expand Up @@ -1210,7 +1210,7 @@ struct MemRefReshapeOpLowering

// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
auto targetDesc = UnrankedMemRefDescriptor::undef(
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
Expand Down Expand Up @@ -1366,7 +1366,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
if (transposeOp.getPermutation().isIdentity())
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

auto targetMemRef = MemRefDescriptor::undef(
auto targetMemRef = MemRefDescriptor::poison(
rewriter, loc,
typeConverter->convertType(transposeOp.getIn().getType()));

Expand Down Expand Up @@ -1469,7 +1469,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {

// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.getSource());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

// Field 1: Copy the allocated pointer, used for malloc/free.
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
Expand Down
Loading