Skip to content

Commit 041f1ab

Browse files
committed
[mlir][memref] Fix num elements in lowering of memref.alloca op to LLVM
Fixes a mistake in the lowering of memref.alloca to llvm.alloca, as llvm.alloca uses the number of elements to allocate in the stack and not the size in bytes. Reference: LLVM IR: https://llvm.org/docs/LangRef.html#alloca-instruction LLVM MLIR: https://mlir.llvm.org/docs/Dialects/LLVM/#llvmalloca-mlirllvmallocaop Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150705
1 parent 986cbd8 commit 041f1ab

File tree

8 files changed

+99
-53
lines changed

8 files changed

+99
-53
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
8282
/// Returns the type of a pointer to an element of the memref.
8383
Type getElementPtrType(MemRefType type) const;
8484

85-
/// Computes sizes, strides and buffer size in bytes of `memRefType` with
86-
/// identity layout. Emits constant ops for the static sizes of `memRefType`,
87-
/// and uses `dynamicSizes` for the others. Emits instructions to compute
88-
/// strides and buffer size from these sizes.
85+
/// Computes sizes, strides and buffer size of `memRefType` with identity
86+
/// layout. Emits constant ops for the static sizes of `memRefType`, and uses
87+
/// `dynamicSizes` for the others. Emits instructions to compute strides and
88+
/// buffer size from these sizes.
8989
///
90-
/// For example, memref<4x?xf32> emits:
90+
/// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
9191
/// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
9292
/// `sizes[1]` = `dynamicSizes[0]`
9393
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
@@ -97,19 +97,27 @@ class ConvertToLLVMPattern : public ConversionPattern {
9797
/// %gep = llvm.getelementptr %nullptr[%size]
9898
/// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
9999
/// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
100+
///
101+
/// If `sizeInBytes = false`, memref<4x?xf32> emits:
102+
/// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
103+
/// `sizes[1]` = `dynamicSizes[0]`
104+
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
105+
/// `strides[0]` = `sizes[0]`
106+
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
100107
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
101108
ValueRange dynamicSizes,
102109
ConversionPatternRewriter &rewriter,
103110
SmallVectorImpl<Value> &sizes,
104-
SmallVectorImpl<Value> &strides,
105-
Value &sizeBytes) const;
111+
SmallVectorImpl<Value> &strides, Value &size,
112+
bool sizeInBytes = true) const;
106113

107114
/// Computes the size of type in bytes.
108115
Value getSizeInBytes(Location loc, Type type,
109116
ConversionPatternRewriter &rewriter) const;
110117

111-
/// Computes total number of elements for the given shape.
112-
Value getNumElements(Location loc, ArrayRef<Value> shape,
118+
/// Computes total number of elements for the given MemRef and dynamicSizes.
119+
Value getNumElements(Location loc, MemRefType memRefType,
120+
ValueRange dynamicSizes,
113121
ConversionPatternRewriter &rewriter) const;
114122

115123
/// Creates and populates a canonical memref descriptor struct.

mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
2020
using ConvertToLLVMPattern::getVoidPtrType;
2121

2222
explicit AllocationOpLLVMLowering(StringRef opName,
23-
LLVMTypeConverter &converter)
24-
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
23+
LLVMTypeConverter &converter,
24+
PatternBenefit benefit = 1)
25+
: ConvertToLLVMPattern(opName, &converter.getContext(), converter,
26+
benefit) {}
2527

2628
protected:
2729
/// Computes the aligned value for 'input' as follows:
@@ -103,15 +105,20 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
103105
/// Lowering for AllocOp and AllocaOp.
104106
struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
105107
explicit AllocLikeOpLLVMLowering(StringRef opName,
106-
LLVMTypeConverter &converter)
107-
: AllocationOpLLVMLowering(opName, converter) {}
108+
LLVMTypeConverter &converter,
109+
PatternBenefit benefit = 1)
110+
: AllocationOpLLVMLowering(opName, converter, benefit) {}
108111

109112
protected:
110113
/// Allocates the underlying buffer. Returns the allocated pointer and the
111114
/// aligned pointer.
112115
virtual std::tuple<Value, Value>
113-
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
114-
Value sizeBytes, Operation *op) const = 0;
116+
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
117+
Operation *op) const = 0;
118+
119+
/// Sets the flag 'requiresNumElements', specifying the Op requires the number
120+
/// of elements instead of the size in bytes.
121+
void setRequiresNumElements();
115122

116123
private:
117124
// An `alloc` is converted into a definition of a memref descriptor value and
@@ -133,6 +140,10 @@ struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
133140
LogicalResult
134141
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
135142
ConversionPatternRewriter &rewriter) const override;
143+
144+
// Flag for specifying the Op requires the number of elements instead of the
145+
// size in bytes.
146+
bool requiresNumElements = false;
136147
};
137148

138149
} // namespace mlir

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
121121
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
122122
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
123123
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
124-
SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
124+
SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
125125
assert(isConvertibleAndHasIdentityMaps(memRefType) &&
126126
"layout maps must have been normalized away");
127127
assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
@@ -143,14 +143,14 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
143143
for (auto i = memRefType.getRank(); i-- > 0;) {
144144
strides[i] = runningStride;
145145

146-
int64_t size = memRefType.getShape()[i];
147-
if (size == 0)
146+
int64_t staticSize = memRefType.getShape()[i];
147+
if (staticSize == 0)
148148
continue;
149149
bool useSizeAsStride = stride == 1;
150-
if (size == ShapedType::kDynamic)
150+
if (staticSize == ShapedType::kDynamic)
151151
stride = ShapedType::kDynamic;
152152
if (stride != ShapedType::kDynamic)
153-
stride *= size;
153+
stride *= staticSize;
154154

155155
if (useSizeAsStride)
156156
runningStride = sizes[i];
@@ -160,14 +160,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
160160
else
161161
runningStride = createIndexConstant(rewriter, loc, stride);
162162
}
163-
164-
// Buffer size in bytes.
165-
Type elementType = typeConverter->convertType(memRefType.getElementType());
166-
Type elementPtrType = getTypeConverter()->getPointerType(elementType);
167-
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
168-
Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, elementType,
169-
nullPtr, runningStride);
170-
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
163+
if (sizeInBytes) {
164+
// Buffer size in bytes.
165+
Type elementType = typeConverter->convertType(memRefType.getElementType());
166+
Type elementPtrType = getTypeConverter()->getPointerType(elementType);
167+
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
168+
Value gepPtr = rewriter.create<LLVM::GEPOp>(
169+
loc, elementPtrType, elementType, nullPtr, runningStride);
170+
size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
171+
} else {
172+
size = runningStride;
173+
}
171174
}
172175

173176
Value ConvertToLLVMPattern::getSizeInBytes(
@@ -186,13 +189,30 @@ Value ConvertToLLVMPattern::getSizeInBytes(
186189
}
187190

188191
Value ConvertToLLVMPattern::getNumElements(
189-
Location loc, ArrayRef<Value> shape,
192+
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
190193
ConversionPatternRewriter &rewriter) const {
194+
assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
195+
static_cast<ssize_t>(dynamicSizes.size()) &&
196+
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
197+
198+
Value numElements = memRefType.getRank() == 0
199+
? createIndexConstant(rewriter, loc, 1)
200+
: nullptr;
201+
unsigned dynamicIndex = 0;
202+
191203
// Compute the total number of memref elements.
192-
Value numElements =
193-
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
194-
for (unsigned i = 1, e = shape.size(); i < e; ++i)
195-
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
204+
for (int64_t staticSize : memRefType.getShape()) {
205+
if (numElements) {
206+
Value size = staticSize == ShapedType::kDynamic
207+
? dynamicSizes[dynamicIndex++]
208+
: createIndexConstant(rewriter, loc, staticSize);
209+
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
210+
} else {
211+
numElements = staticSize == ShapedType::kDynamic
212+
? dynamicSizes[dynamicIndex++]
213+
: createIndexConstant(rewriter, loc, staticSize);
214+
}
215+
}
196216
return numElements;
197217
}
198218

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
156156
elementPtrType, *getTypeConverter());
157157
}
158158

159+
void AllocLikeOpLLVMLowering::setRequiresNumElements() {
160+
requiresNumElements = true;
161+
}
162+
159163
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
160164
Operation *op, ArrayRef<Value> operands,
161165
ConversionPatternRewriter &rewriter) const {
@@ -169,13 +173,14 @@ LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
169173
// zero-dimensional memref, assume a scalar (size 1).
170174
SmallVector<Value, 4> sizes;
171175
SmallVector<Value, 4> strides;
172-
Value sizeBytes;
176+
Value size;
177+
173178
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
174-
strides, sizeBytes);
179+
strides, size, !requiresNumElements);
175180

176181
// Allocate the underlying buffer.
177182
auto [allocatedPtr, alignedPtr] =
178-
this->allocateBuffer(rewriter, loc, sizeBytes, op);
183+
this->allocateBuffer(rewriter, loc, size, op);
179184

180185
// Create the MemRef descriptor.
181186
auto memRefDescriptor = this->createMemRefDescriptor(

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,15 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
8585
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
8686
AllocaOpLowering(LLVMTypeConverter &converter)
8787
: AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
88-
converter) {}
88+
converter) {
89+
setRequiresNumElements();
90+
}
8991

9092
/// Allocates the underlying buffer using the right call. `allocatedBytePtr`
9193
/// is set to null for stack allocations. `accessAlignment` is set if
9294
/// alignment is needed post allocation (for eg. in conjunction with malloc).
9395
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
94-
Location loc, Value sizeBytes,
96+
Location loc, Value size,
9597
Operation *op) const override {
9698

9799
// With alloca, one gets a pointer to the element type right away.
@@ -104,9 +106,9 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
104106
auto elementPtrType =
105107
getTypeConverter()->getPointerType(elementType, addrSpace);
106108

107-
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
108-
loc, elementPtrType, elementType, sizeBytes,
109-
allocaOp.getAlignment().value_or(0));
109+
auto allocatedElementPtr =
110+
rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
111+
allocaOp.getAlignment().value_or(0));
110112

111113
return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
112114
}

mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,15 @@ gpu.module @test_module {
9191
%j = arith.constant 16 : index
9292
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
9393
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
94-
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
94+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
95+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
96+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
97+
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
9598
// CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
9699
// CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
97100
// CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
98101
// CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
99-
// CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
102+
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
100103
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
101104
// CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
102105
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
@@ -107,12 +110,15 @@ gpu.module @test_module {
107110
// CHECK: llvm.return
108111

109112
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
110-
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
113+
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
114+
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
115+
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
116+
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
111117
// CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
112118
// CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
113119
// CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
114120
// CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
115-
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
121+
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
116122
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
117123
// CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
118124
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32

mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
8686
// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
8787
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
8888
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64
89-
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
90-
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
91-
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
92-
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x f32 : (i64) -> !llvm.ptr
89+
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
9390
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
9491
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
9592
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ func.func @static_alloca() -> memref<32x18xf32> {
7979
// CHECK: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64
8080
// CHECK: %[[st2:.*]] = llvm.mlir.constant(1 : index) : i64
8181
// CHECK: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64
82-
// CHECK: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
83-
// CHECK: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
84-
// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
85-
// CHECK: %[[allocated:.*]] = llvm.alloca %[[size_bytes]] x f32 : (i64) -> !llvm.ptr
82+
// CHECK: %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
8683
%0 = memref.alloca() : memref<32x18xf32>
8784

8885
// Test with explicitly specified alignment. llvm.alloca takes care of the

0 commit comments

Comments
 (0)