Skip to content

Commit e216a72

Browse files
ftynse authored and tensorflower-gardener committed
Add conversions of GPU func with memory attributions to LLVM/NVVM
GPU functions use memory attributions, a combination of Op attributes and region arguments, to specify function-wide buffers placed in workgroup or private memory spaces. Introduce a lowering pattern for GPU functions to be converted to LLVM functions taking into account memory attributions. Workgroup attributions get transformed into module-level globals with unique names derived from function names. Private attributions get converted into llvm.allocas inside the function body. In both cases, we inject at the beginning of the function the IR that obtains the raw pointer to the data and populates a MemRef descriptor based on the MemRef type of buffer, making attributions compose with the rest of the MemRef lowering and transparent for use with std.load and std.store. While using raw pointers instead of descriptors might have been more efficient, it is better implemented as a canonicalization or a separate transformation so that non-attribution memrefs could also benefit from it. PiperOrigin-RevId: 284208396
1 parent 3c69ca1 commit e216a72

File tree

5 files changed

+371
-9
lines changed

5 files changed

+371
-9
lines changed

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ class MemRefDescriptor : public StructBuilder {
168168
/// Builds IR creating an `undef` value of the descriptor type.
169169
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
170170
Type descriptorType);
171+
/// Builds IR creating a MemRef descriptor that represents `type` and
172+
/// populates it with static shape and stride information extracted from the
173+
/// type.
174+
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
175+
LLVMTypeConverter &typeConverter,
176+
MemRefType type, Value *memory);
177+
171178
/// Builds IR extracting the allocated pointer from the descriptor.
172179
Value *allocatedPtr(OpBuilder &builder, Location loc);
173180
/// Builds IR inserting the allocated pointer into the descriptor.
@@ -184,18 +191,23 @@ class MemRefDescriptor : public StructBuilder {
184191

185192
/// Builds IR inserting the offset into the descriptor.
186193
void setOffset(OpBuilder &builder, Location loc, Value *offset);
194+
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
187195

188196
/// Builds IR extracting the pos-th size from the descriptor.
189197
Value *size(OpBuilder &builder, Location loc, unsigned pos);
190198

191199
/// Builds IR inserting the pos-th size into the descriptor
192200
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
201+
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
202+
uint64_t size);
193203

194204
/// Builds IR extracting the pos-th size from the descriptor.
195205
Value *stride(OpBuilder &builder, Location loc, unsigned pos);
196206

197207
/// Builds IR inserting the pos-th stride into the descriptor
198208
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
209+
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
210+
uint64_t stride);
199211

200212
/// Returns the (LLVM) type this descriptor points to.
201213
LLVM::LLVMType getElementType();

mlir/include/mlir/Dialect/GPU/GPUDialect.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class GPUDialect : public Dialect {
6161
/// 'gpu.kernel' attribute.
6262
static bool isKernel(Operation *op);
6363

64+
/// Returns the numeric value used to identify the workgroup memory address
65+
/// space.
66+
static int getWorkgroupAddressSpace() { return 3; }
67+
6468
LogicalResult verifyOperationAttribute(Operation *op,
6569
NamedAttribute attr) override;
6670
};
@@ -249,6 +253,12 @@ class GPUFuncOp : public Op<GPUFuncOp, OpTrait::FunctionLike,
249253
return {begin, getBody().front().args_end()};
250254
}
251255

256+
/// Returns the name of the attribute containing the number of buffers located
257+
/// in the workgroup memory.
258+
static StringRef getNumWorkgroupAttributionsAttrName() {
259+
return "workgroup_attibutions";
260+
}
261+
252262
private:
253263
// FunctionLike trait needs access to the functions below.
254264
friend class OpTrait::FunctionLike<GPUFuncOp>;
@@ -257,12 +267,6 @@ class GPUFuncOp : public Op<GPUFuncOp, OpTrait::FunctionLike,
257267
unsigned getNumFuncArguments() { return getType().getNumInputs(); }
258268
unsigned getNumFuncResults() { return getType().getNumResults(); }
259269

260-
/// Returns the name of the attribute containing the number of buffers located
261-
/// in the workgroup memory.
262-
static StringRef getNumWorkgroupAttributionsAttrName() {
263-
return "workgroup_attibutions";
264-
}
265-
266270
/// Returns the keywords used in the custom syntax for this Op.
267271
static StringRef getWorkgroupKeyword() { return "workgroup"; }
268272
static StringRef getPrivateKeyword() { return "private"; }

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "mlir/Pass/Pass.h"
3030
#include "mlir/Transforms/DialectConversion.h"
3131

32+
#include "llvm/Support/FormatVariadic.h"
33+
3234
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
3335
#include "../GPUCommon/OpToFuncCallLowering.h"
3436

@@ -451,6 +453,146 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
451453
static constexpr int kWarpSize = 32;
452454
};
453455

456+
namespace {
457+
458+
struct FuncOpLowering : LLVMOpLowering {
459+
explicit FuncOpLowering(LLVMTypeConverter &typeConverter)
460+
: LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
461+
typeConverter.getDialect()->getContext(),
462+
typeConverter) {}
463+
464+
PatternMatchResult
465+
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
466+
ConversionPatternRewriter &rewriter) const override {
467+
assert(operands.empty() && "func op is not expected to have operands");
468+
auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
469+
Location loc = gpuFuncOp.getLoc();
470+
471+
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
472+
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
473+
for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
474+
Value *attribution = en.value();
475+
476+
auto type = attribution->getType().dyn_cast<MemRefType>();
477+
assert(type && type.hasStaticShape() && "unexpected type in attribution");
478+
479+
uint64_t numElements = type.getNumElements();
480+
481+
auto elementType =
482+
lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>();
483+
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
484+
auto addSpaceAttr = rewriter.getNamedAttr(
485+
"addr_space", rewriter.getI32IntegerAttr(
486+
gpu::GPUDialect::getWorkgroupAddressSpace()));
487+
std::string name =
488+
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index());
489+
auto globalOp = rewriter.create<LLVM::GlobalOp>(
490+
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
491+
LLVM::Linkage::Internal, name, /*value=*/Attribute(),
492+
llvm::makeArrayRef(addSpaceAttr));
493+
workgroupBuffers.push_back(globalOp);
494+
}
495+
496+
// Rewrite the original GPU function to an LLVM function.
497+
// TODO(zinenko): there is a hack in the std->llvm lowering that promotes
498+
// structs to pointers that probably needs to be replicated here.
499+
auto funcType = lowering.convertType(gpuFuncOp.getType())
500+
.cast<LLVM::LLVMType>()
501+
.getPointerElementTy();
502+
503+
// Remap proper input types.
504+
TypeConverter::SignatureConversion signatureConversion(
505+
gpuFuncOp.front().getNumArguments());
506+
for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
507+
signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
508+
509+
// Create the new function operation. Only copy those attributes that are
510+
// not specific to function modeling.
511+
SmallVector<NamedAttribute, 4> attributes;
512+
for (const auto &attr : gpuFuncOp.getAttrs()) {
513+
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
514+
attr.first.is(impl::getTypeAttrName()) ||
515+
attr.first.is(gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()))
516+
continue;
517+
attributes.push_back(attr);
518+
}
519+
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
520+
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
521+
LLVM::Linkage::External, attributes);
522+
523+
{
524+
// Insert operations that correspond to converted workgroup and private
525+
// memory attributions to the body of the function. This must operate on
526+
// the original function, before the body region is inlined in the new
527+
// function to maintain the relation between block arguments and the
528+
// parent operation that assigns their semantics.
529+
OpBuilder::InsertionGuard guard(rewriter);
530+
531+
// Rewrite workgroup memory attributions to addresses of global buffers.
532+
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
533+
unsigned numProperArguments = gpuFuncOp.getNumArguments();
534+
auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
535+
536+
Value *zero = nullptr;
537+
if (!workgroupBuffers.empty())
538+
zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
539+
rewriter.getI32IntegerAttr(0));
540+
for (auto en : llvm::enumerate(workgroupBuffers)) {
541+
LLVM::GlobalOp global = en.value();
542+
Value *address = rewriter.create<LLVM::AddressOfOp>(loc, global);
543+
auto elementType = global.getType().getArrayElementType();
544+
Value *memory = rewriter.create<LLVM::GEPOp>(
545+
loc, elementType.getPointerTo(global.addr_space().getZExtValue()),
546+
address, ArrayRef<Value *>{zero, zero});
547+
548+
// Build a memref descriptor pointing to the buffer to plug with the
549+
// existing memref infrastructure. This may use more registers than
550+
// otherwise necessary given that memref sizes are fixed, but we can try
551+
// and canonicalize that away later.
552+
Value *attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
553+
auto type = attribution->getType().cast<MemRefType>();
554+
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
555+
type, memory);
556+
signatureConversion.remapInput(numProperArguments + en.index(), descr);
557+
}
558+
559+
// Rewrite private memory attributions to alloca'ed buffers.
560+
unsigned numWorkgroupAttributions =
561+
gpuFuncOp.getNumWorkgroupAttributions();
562+
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
563+
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
564+
Value *attribution = en.value();
565+
auto type = attribution->getType().cast<MemRefType>();
566+
assert(type && type.hasStaticShape() &&
567+
"unexpected type in attribution");
568+
569+
auto ptrType = lowering.convertType(type.getElementType())
570+
.cast<LLVM::LLVMType>()
571+
.getPointerTo(type.getMemorySpace());
572+
Value *numElements = rewriter.create<LLVM::ConstantOp>(
573+
gpuFuncOp.getLoc(), int64Ty,
574+
rewriter.getI64IntegerAttr(type.getNumElements()));
575+
Value *allocated = rewriter.create<LLVM::AllocaOp>(
576+
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
577+
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
578+
type, allocated);
579+
signatureConversion.remapInput(
580+
numProperArguments + numWorkgroupAttributions + en.index(), descr);
581+
}
582+
}
583+
584+
rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
585+
llvmFuncOp.end());
586+
rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
587+
signatureConversion);
588+
589+
rewriter.eraseOp(gpuFuncOp);
590+
return matchSuccess();
591+
}
592+
};
593+
594+
} // end namespace
595+
454596
/// Import the GPU Ops to NVVM Patterns.
455597
#include "GPUToNVVM.cpp.inc"
456598

@@ -479,12 +621,13 @@ class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
479621
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
480622
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
481623
NVVM::GridDimYOp, NVVM::GridDimZOp>,
482-
GPUAllReduceOpLowering>(converter);
624+
GPUAllReduceOpLowering, FuncOpLowering>(converter);
483625
patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
484626
"__nv_exp");
485627
ConversionTarget target(getContext());
486628
target.addIllegalDialect<gpu::GPUDialect>();
487629
target.addIllegalOp<LLVM::ExpOp>();
630+
target.addIllegalOp<FuncOp>();
488631
target.addLegalDialect<LLVM::LLVMDialect>();
489632
target.addLegalDialect<NVVM::NVVMDialect>();
490633
// TODO(csigg): Remove once we support replacing non-root ops.

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,36 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
304304
return MemRefDescriptor(descriptor);
305305
}
306306

307+
/// Builds IR creating a MemRef descriptor that represents `type` and
308+
/// populates it with static shape and stride information extracted from the
309+
/// type.
310+
MemRefDescriptor
311+
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
312+
LLVMTypeConverter &typeConverter,
313+
MemRefType type, Value *memory) {
314+
assert(type.hasStaticShape() && "unexpected dynamic shape");
315+
assert(type.getAffineMaps().empty() && "unexpected layout map");
316+
317+
auto convertedType = typeConverter.convertType(type);
318+
assert(convertedType && "unexpected failure in memref type conversion");
319+
320+
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
321+
descr.setAllocatedPtr(builder, loc, memory);
322+
descr.setAlignedPtr(builder, loc, memory);
323+
descr.setConstantOffset(builder, loc, 0);
324+
325+
// Fill in sizes and strides, in reverse order to simplify stride
326+
// calculation.
327+
uint64_t runningStride = 1;
328+
for (unsigned i = type.getRank(); i > 0; --i) {
329+
unsigned dim = i - 1;
330+
descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
331+
descr.setConstantStride(builder, loc, dim, runningStride);
332+
runningStride *= type.getDimSize(dim);
333+
}
334+
return descr;
335+
}
336+
307337
/// Builds IR extracting the allocated pointer from the descriptor.
308338
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
309339
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
@@ -326,6 +356,14 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
326356
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
327357
}
328358

359+
// Creates a constant Op producing a value of `resultType` from an index-typed
360+
// integer attribute.
361+
static Value *createIndexAttrConstant(OpBuilder &builder, Location loc,
362+
Type resultType, int64_t value) {
363+
return builder.create<LLVM::ConstantOp>(
364+
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
365+
}
366+
329367
/// Builds IR extracting the offset from the descriptor.
330368
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
331369
return builder.create<LLVM::ExtractValueOp>(
@@ -341,6 +379,13 @@ void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
341379
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
342380
}
343381

382+
/// Builds IR inserting the offset into the descriptor.
383+
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
384+
uint64_t offset) {
385+
setOffset(builder, loc,
386+
createIndexAttrConstant(builder, loc, indexType, offset));
387+
}
388+
344389
/// Builds IR extracting the pos-th size from the descriptor.
345390
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
346391
return builder.create<LLVM::ExtractValueOp>(
@@ -356,6 +401,13 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
356401
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
357402
}
358403

404+
/// Builds IR inserting the pos-th size into the descriptor
405+
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
406+
unsigned pos, uint64_t size) {
407+
setSize(builder, loc, pos,
408+
createIndexAttrConstant(builder, loc, indexType, size));
409+
}
410+
359411
/// Builds IR extracting the pos-th size from the descriptor.
360412
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
361413
unsigned pos) {
@@ -372,6 +424,13 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
372424
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
373425
}
374426

427+
/// Builds IR inserting the pos-th stride into the descriptor
428+
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
429+
unsigned pos, uint64_t stride) {
430+
setStride(builder, loc, pos,
431+
createIndexAttrConstant(builder, loc, indexType, stride));
432+
}
433+
375434
LLVM::LLVMType MemRefDescriptor::getElementType() {
376435
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
377436
kAlignedPtrPosInMemRefDescriptor);
@@ -448,8 +507,7 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
448507
// Create an LLVM IR pseudo-operation defining the given index constant.
449508
Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
450509
uint64_t value) const {
451-
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
452-
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
510+
return createIndexAttrConstant(builder, loc, getIndexType(), value);
453511
}
454512

455513
protected:

0 commit comments

Comments (0)