@@ -29,6 +29,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include "llvm/Support/FormatVariadic.h"
+
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
 #include "../GPUCommon/OpToFuncCallLowering.h"
 
@@ -451,6 +453,146 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
   static constexpr int kWarpSize = 32;
 };
 
+namespace {
+
+struct FuncOpLowering : LLVMOpLowering {
+  explicit FuncOpLowering(LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
+                       typeConverter.getDialect()->getContext(),
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.empty() && "func op is not expected to have operands");
+    auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
+    Location loc = gpuFuncOp.getLoc();
+
+    SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
+    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
+    for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
+      Value *attribution = en.value();
+
+      auto type = attribution->getType().dyn_cast<MemRefType>();
+      assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+      uint64_t numElements = type.getNumElements();
+
+      auto elementType =
+          lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+      auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
+      auto addSpaceAttr = rewriter.getNamedAttr(
+          "addr_space", rewriter.getI32IntegerAttr(
+                            gpu::GPUDialect::getWorkgroupAddressSpace()));
+      std::string name =
+          llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index());
+      auto globalOp = rewriter.create<LLVM::GlobalOp>(
+          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+          LLVM::Linkage::Internal, name, /*value=*/Attribute(),
+          llvm::makeArrayRef(addSpaceAttr));
+      workgroupBuffers.push_back(globalOp);
+    }
+
+    // Rewrite the original GPU function to an LLVM function.
+    // TODO(zinenko): there is a hack in the std->llvm lowering that promotes
+    // structs to pointers that probably needs to be replicated here.
+    auto funcType = lowering.convertType(gpuFuncOp.getType())
+                        .cast<LLVM::LLVMType>()
+                        .getPointerElementTy();
+
+    // Remap proper input types.
+    TypeConverter::SignatureConversion signatureConversion(
+        gpuFuncOp.front().getNumArguments());
+    for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
+      signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
+
+    // Create the new function operation. Only copy those attributes that are
+    // not specific to function modeling.
+    SmallVector<NamedAttribute, 4> attributes;
+    for (const auto &attr : gpuFuncOp.getAttrs()) {
+      if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+          attr.first.is(impl::getTypeAttrName()) ||
+          attr.first.is(gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()))
+        continue;
+      attributes.push_back(attr);
+    }
+    auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
+        LLVM::Linkage::External, attributes);
+
+    {
+      // Insert operations that correspond to converted workgroup and private
+      // memory attributions to the body of the function. This must operate on
+      // the original function, before the body region is inlined into the new
+      // function, to maintain the relation between block arguments and the
+      // parent operation that assigns their semantics.
+      OpBuilder::InsertionGuard guard(rewriter);
+
+      // Rewrite workgroup memory attributions to addresses of global buffers.
+      rewriter.setInsertionPointToStart(&gpuFuncOp.front());
+      unsigned numProperArguments = gpuFuncOp.getNumArguments();
+      auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+
+      Value *zero = nullptr;
+      if (!workgroupBuffers.empty())
+        zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
+                                                 rewriter.getI32IntegerAttr(0));
+      for (auto en : llvm::enumerate(workgroupBuffers)) {
+        LLVM::GlobalOp global = en.value();
+        Value *address = rewriter.create<LLVM::AddressOfOp>(loc, global);
+        auto elementType = global.getType().getArrayElementType();
+        Value *memory = rewriter.create<LLVM::GEPOp>(
+            loc, elementType.getPointerTo(global.addr_space().getZExtValue()),
+            address, ArrayRef<Value *>{zero, zero});
+
+        // Build a memref descriptor pointing to the buffer to plug into the
+        // existing memref infrastructure. This may use more registers than
+        // otherwise necessary given that memref sizes are fixed, but we can
+        // try to canonicalize that away later.
+        Value *attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
+        auto type = attribution->getType().cast<MemRefType>();
+        auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
+                                                       type, memory);
+        signatureConversion.remapInput(numProperArguments + en.index(), descr);
+      }
+
+      // Rewrite private memory attributions to alloca'ed buffers.
+      unsigned numWorkgroupAttributions =
+          gpuFuncOp.getNumWorkgroupAttributions();
+      auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+      for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
+        Value *attribution = en.value();
+        auto type = attribution->getType().cast<MemRefType>();
+        assert(type && type.hasStaticShape() &&
+               "unexpected type in attribution");
+
+        auto ptrType = lowering.convertType(type.getElementType())
+                           .cast<LLVM::LLVMType>()
+                           .getPointerTo(type.getMemorySpace());
+        Value *numElements = rewriter.create<LLVM::ConstantOp>(
+            gpuFuncOp.getLoc(), int64Ty,
+            rewriter.getI64IntegerAttr(type.getNumElements()));
+        Value *allocated = rewriter.create<LLVM::AllocaOp>(
+            gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
+        auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
+                                                       type, allocated);
+        signatureConversion.remapInput(
+            numProperArguments + numWorkgroupAttributions + en.index(), descr);
+      }
+    }
+
+    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
+                                llvmFuncOp.end());
+    rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
+                                      signatureConversion);
+
+    rewriter.eraseOp(gpuFuncOp);
+    return matchSuccess();
+  }
+};
+
+} // end namespace
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -479,12 +621,13 @@ class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
                                             NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
                 GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                             NVVM::GridDimYOp, NVVM::GridDimZOp>,
-                GPUAllReduceOpLowering>(converter);
+                GPUAllReduceOpLowering, FuncOpLowering>(converter);
     patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                                  "__nv_exp");
     ConversionTarget target(getContext());
     target.addIllegalDialect<gpu::GPUDialect>();
     target.addIllegalOp<LLVM::ExpOp>();
+    target.addIllegalOp<FuncOp>();
     target.addLegalDialect<LLVM::LLVMDialect>();
     target.addLegalDialect<NVVM::NVVMDialect>();
     // TODO(csigg): Remove once we support replacing non-root ops.