Skip to content

Commit 7030280

Browse files
[mlir][GPU] Improve gpu.module op implementation (#102866)
- Replace hand-written parser/printer with auto-generated assembly format. - Remove implicit `gpu.module_end` terminator and use the `NoTerminator` trait instead. (Same as `builtin.module`.) - Turn the region into a graph region. (Same as `builtin.module`.)
1 parent f5ba3f6 commit 7030280

File tree

8 files changed

+20
-105
lines changed

8 files changed

+20
-105
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/Dialect.h"
2323
#include "mlir/IR/OpDefinition.h"
2424
#include "mlir/IR/OpImplementation.h"
25+
#include "mlir/IR/RegionKindInterface.h"
2526
#include "mlir/IR/SymbolTable.h"
2627
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2728
#include "mlir/Interfaces/FunctionInterfaces.h"

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include "mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td"
2121
include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
2222
include "mlir/IR/CommonTypeConstraints.td"
2323
include "mlir/IR/EnumAttr.td"
24+
include "mlir/IR/RegionKindInterface.td"
2425
include "mlir/IR/SymbolInterfaces.td"
2526
include "mlir/Interfaces/ControlFlowInterfaces.td"
2627
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -1347,10 +1348,7 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
13471348

13481349
def GPU_GPUModuleOp : GPU_Op<"module", [
13491350
DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
1350-
SymbolTable, Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
1351-
]>, Arguments<(ins SymbolNameAttr:$sym_name,
1352-
OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
1353-
OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler)> {
1351+
NoRegionArguments, SymbolTable, Symbol] # GraphRegionNoTerminator.traits> {
13541352
let summary = "A top level compilation unit containing code to be run on a GPU.";
13551353
let description = [{
13561354
GPU module contains code that is intended to be run on a GPU. A host device
@@ -1379,15 +1377,13 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
13791377
gpu.module @symbol_name {
13801378
gpu.func {}
13811379
...
1382-
gpu.module_end
13831380
}
13841381
// Module with offloading handler and target attributes.
13851382
gpu.module @symbol_name2 <#gpu.select_object<1>> [
13861383
#nvvm.target,
13871384
#rocdl.target<chip = "gfx90a">] {
13881385
gpu.func {}
13891386
...
1390-
gpu.module_end
13911387
}
13921388
```
13931389
}];
@@ -1399,8 +1395,18 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
13991395
"ArrayRef<Attribute>":$targets,
14001396
CArg<"Attribute", "{}">:$handler)>
14011397
];
1398+
1399+
let arguments = (ins
1400+
SymbolNameAttr:$sym_name,
1401+
OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
1402+
OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler);
14021403
let regions = (region SizedRegion<1>:$bodyRegion);
1403-
let hasCustomAssemblyFormat = 1;
1404+
let assemblyFormat = [{
1405+
$sym_name
1406+
(`<` $offloadingHandler^ `>`)?
1407+
($targets^)?
1408+
attr-dict-with-keyword $bodyRegion
1409+
}];
14041410

14051411
// We need to ensure the block inside the region is properly terminated;
14061412
// the auto-generated builders do not guarantee that.
@@ -1415,17 +1421,6 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
14151421
}];
14161422
}
14171423

1418-
def GPU_ModuleEndOp : GPU_Op<"module_end", [
1419-
Terminator, HasParent<"GPUModuleOp">
1420-
]> {
1421-
let summary = "A pseudo op that marks the end of a gpu.module.";
1422-
let description = [{
1423-
This op terminates the only block inside the only region of a `gpu.module`.
1424-
}];
1425-
1426-
let assemblyFormat = "attr-dict";
1427-
}
1428-
14291424
def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
14301425
SymbolNameAttr:$sym_name,
14311426
OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler,

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
316316
LLVM::SinOp, LLVM::SqrtOp>();
317317

318318
// TODO: Remove once we support replacing non-root ops.
319-
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
319+
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
320320
}
321321

322322
template <typename OpTy>

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
335335
LLVM::SqrtOp>();
336336

337337
// TODO: Remove once we support replacing non-root ops.
338-
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
338+
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
339339
}
340340

341341
template <typename OpTy>

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,6 @@ class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
9090
ConversionPatternRewriter &rewriter) const override;
9191
};
9292

93-
class GPUModuleEndConversion final
94-
: public OpConversionPattern<gpu::ModuleEndOp> {
95-
public:
96-
using OpConversionPattern::OpConversionPattern;
97-
98-
LogicalResult
99-
matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
100-
ConversionPatternRewriter &rewriter) const override {
101-
rewriter.eraseOp(endOp);
102-
return success();
103-
}
104-
};
105-
10693
/// Pattern to convert a gpu.return into a SPIR-V return.
10794
// TODO: This can go to DRR when GPU return has operands.
10895
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
@@ -614,7 +601,7 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
614601
RewritePatternSet &patterns) {
615602
patterns.add<
616603
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
617-
GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
604+
GPUReturnOpConversion, GPUShuffleConversion,
618605
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
619606
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
620607
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,8 +1736,7 @@ LogicalResult gpu::ReturnOp::verify() {
17361736
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
17371737
StringRef name, ArrayAttr targets,
17381738
Attribute offloadingHandler) {
1739-
ensureTerminator(*result.addRegion(), builder, result.location);
1740-
1739+
result.addRegion()->emplaceBlock();
17411740
Properties &props = result.getOrAddProperties<Properties>();
17421741
if (targets)
17431742
props.targets = targets;
@@ -1753,73 +1752,6 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
17531752
offloadingHandler);
17541753
}
17551754

1756-
ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1757-
StringAttr nameAttr;
1758-
ArrayAttr targetsAttr;
1759-
1760-
if (parser.parseSymbolName(nameAttr))
1761-
return failure();
1762-
1763-
Properties &props = result.getOrAddProperties<Properties>();
1764-
props.setSymName(nameAttr);
1765-
1766-
// Parse the optional offloadingHandler
1767-
if (succeeded(parser.parseOptionalLess())) {
1768-
if (parser.parseAttribute(props.offloadingHandler))
1769-
return failure();
1770-
if (parser.parseGreater())
1771-
return failure();
1772-
}
1773-
1774-
// Parse the optional array of target attributes.
1775-
OptionalParseResult targetsAttrResult =
1776-
parser.parseOptionalAttribute(targetsAttr, Type{});
1777-
if (targetsAttrResult.has_value()) {
1778-
if (failed(*targetsAttrResult)) {
1779-
return failure();
1780-
}
1781-
props.targets = targetsAttr;
1782-
}
1783-
1784-
// If module attributes are present, parse them.
1785-
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1786-
return failure();
1787-
1788-
// Parse the module body.
1789-
auto *body = result.addRegion();
1790-
if (parser.parseRegion(*body, {}))
1791-
return failure();
1792-
1793-
// Ensure that this module has a valid terminator.
1794-
GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1795-
return success();
1796-
}
1797-
1798-
void GPUModuleOp::print(OpAsmPrinter &p) {
1799-
p << ' ';
1800-
p.printSymbolName(getName());
1801-
1802-
if (Attribute attr = getOffloadingHandlerAttr()) {
1803-
p << " <";
1804-
p.printAttribute(attr);
1805-
p << ">";
1806-
}
1807-
1808-
if (Attribute attr = getTargetsAttr()) {
1809-
p << ' ';
1810-
p.printAttribute(attr);
1811-
p << ' ';
1812-
}
1813-
1814-
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1815-
{mlir::SymbolTable::getSymbolAttrName(),
1816-
getTargetsAttrName(),
1817-
getOffloadingHandlerAttrName()});
1818-
p << ' ';
1819-
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1820-
/*printBlockTerminators=*/false);
1821-
}
1822-
18231755
bool GPUModuleOp::hasTarget(Attribute target) {
18241756
if (ArrayAttr targets = getTargetsAttr())
18251757
return llvm::count(targets.getValue(), target);

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ module attributes {transform.with_named_sequence} {
6767
{index_bitwidth = 32, use_opaque_pointers = true}
6868
} {
6969
legal_dialects = ["llvm", "memref", "nvvm"],
70-
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
70+
legal_ops = ["func.func", "gpu.module", "gpu.yield"],
7171
illegal_dialects = ["gpu"],
7272
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
7373
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ module attributes {transform.with_named_sequence} {
942942
use_bare_ptr_call_conv = false}
943943
} {
944944
legal_dialects = ["llvm", "memref", "nvvm", "test"],
945-
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
945+
legal_ops = ["func.func", "gpu.module", "gpu.yield"],
946946
illegal_dialects = ["gpu"],
947947
illegal_ops = ["llvm.copysign", "llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
948948
"llvm.ffloor", "llvm.fma", "llvm.frem", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",

0 commit comments

Comments
 (0)