Skip to content

Commit 920c461

Browse files
[mlir][Transform] Add support to drive conversions of func to LLVM with TD
This revision adds a `transform.apply_conversion_patterns.func.func_to_llvm` transformation. It is unclear at this point whether this should be spelled out as a standalone transformation or whether it should resemble `transform.apply_conversion_patterns.dialect_to_llvm "fun"`. This is dependent on how we want to handle the type converter creation. In particular the current implementation exhibits the fact that `transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter` was not rich enough and did not match the LowerToLLVMOptions. Keeping those options in sync across all the passes that lower to LLVM is very error prone. Instead, we should have a single `to_llvm_type_converter`. Differential Revision: https://reviews.llvm.org/D157553
1 parent e53b28c commit 920c461

File tree

15 files changed

+288
-43
lines changed

15 files changed

+288
-43
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS FuncTransformOps.td)
2+
mlir_tablegen(FuncTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(FuncTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRFuncTransformOpsIncGen)
5+
6+
add_mlir_doc(FuncTransformOps FuncTransformOps Dialects/ -gen-op-doc)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- FuncTransformOps.h - CF transformation ops --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H
10+
#define MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14+
#include "mlir/IR/OpImplementation.h"
15+
16+
#define GET_OP_CLASSES
17+
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h.inc"
18+
19+
namespace mlir {
20+
class DialectRegistry;
21+
22+
namespace func {
23+
void registerTransformDialectExtension(DialectRegistry &registry);
24+
} // namespace func
25+
} // namespace mlir
26+
27+
#endif // MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- FuncTransformOps.td - CF transformation ops -*- tablegen -*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FUNC_TRANSFORM_OPS
10+
#define FUNC_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
14+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15+
include "mlir/IR/OpBase.td"
16+
17+
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
18+
"apply_conversion_patterns.func.func_to_llvm",
19+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
20+
["verifyTypeConverter"]>]> {
21+
let description = [{
22+
Collects patterns that convert Func dialect ops to LLVM dialect ops.
23+
These patterns require an "LLVMTypeConverter".
24+
}];
25+
26+
let assemblyFormat = "attr-dict";
27+
}
28+
29+
#endif // FUNC_TRANSFORM_OPS

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,22 @@ def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
3232
instead of the classic "malloc", "aligned_alloc" and "free" functions.
3333
- `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of
3434
typed pointers.
35+
// TODO: the following two options don't really make sense for
36+
// memref_to_llvm_type_converter specifically.
37+
// We should have a single to_llvm_type_converter.
38+
- `use_bare_ptr_call_conv`: Replace FuncOp's MemRef arguments with bare
39+
pointers to the MemRef element types.
40+
- `data-layout`: String description (LLVM format) of the data layout that is
41+
expected on the produced module.
3542
}];
3643

3744
let arguments = (ins
38-
DefaultValuedAttr<BoolAttr, "false">:$use_aligned_alloc,
39-
DefaultValuedAttr<I64Attr, "0">:$index_bitwidth,
40-
DefaultValuedAttr<BoolAttr, "false">:$use_generic_functions,
41-
DefaultValuedAttr<BoolAttr, "false">:$use_opaque_pointers);
45+
DefaultValuedOptionalAttr<BoolAttr, "false">:$use_aligned_alloc,
46+
DefaultValuedOptionalAttr<I64Attr, "64">:$index_bitwidth,
47+
DefaultValuedOptionalAttr<BoolAttr, "false">:$use_generic_functions,
48+
DefaultValuedOptionalAttr<BoolAttr, "false">:$use_opaque_pointers,
49+
DefaultValuedOptionalAttr<BoolAttr, "false">:$use_bare_ptr_call_conv,
50+
OptionalAttr<StrAttr>:$data_layout);
4251
let assemblyFormat = "attr-dict";
4352
}
4453

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/DLTI/DLTI.h"
3535
#include "mlir/Dialect/EmitC/IR/EmitC.h"
3636
#include "mlir/Dialect/Func/IR/FuncOps.h"
37+
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
3738
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3839
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
3940
#include "mlir/Dialect/IRDL/IR/IRDL.h"
@@ -138,6 +139,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
138139
// Register all dialect extensions.
139140
affine::registerTransformDialectExtension(registry);
140141
bufferization::registerTransformDialectExtension(registry);
142+
func::registerTransformDialectExtension(registry);
141143
gpu::registerTransformDialectExtension(registry);
142144
linalg::registerTransformDialectExtension(registry);
143145
memref::registerTransformDialectExtension(registry);

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
333333

334334
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
335335
// to this legalization pattern.
336-
LLVM::LLVMFuncOp
336+
FailureOr<LLVM::LLVMFuncOp>
337337
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
338338
ConversionPatternRewriter &rewriter) const {
339339
// Convert the original function arguments. They are converted using the
@@ -344,7 +344,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
344344
funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
345345
shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
346346
if (!llvmType)
347-
return nullptr;
347+
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
348348

349349
// Propagate argument/result attributes to all converted arguments/result
350350
// obtained after converting a given original argument/result.
@@ -423,7 +423,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
423423
if (!attr) {
424424
funcOp->emitError() << "Contains " << linkageAttrName
425425
<< " attribute not of type LLVM::LinkageAttr";
426-
return nullptr;
426+
return rewriter.notifyMatchFailure(
427+
funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
427428
}
428429
linkage = attr.getLinkage();
429430
}
@@ -436,7 +437,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
436437
if (!attr) {
437438
funcOp->emitError() << "Contains " << readnoneAttrName
438439
<< " attribute not of type UnitAttr";
439-
return nullptr;
440+
return rewriter.notifyMatchFailure(
441+
funcOp, "Contains readnone attribute not of type UnitAttr");
440442
}
441443
memoryAttr = LLVM::MemoryEffectsAttr::get(rewriter.getContext(),
442444
{LLVM::ModRefInfo::NoModRef,
@@ -453,8 +455,10 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
453455
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
454456
newFuncOp.end());
455457
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
456-
&result)))
457-
return nullptr;
458+
&result))) {
459+
return rewriter.notifyMatchFailure(funcOp,
460+
"region types conversion failed");
461+
}
458462

459463
return newFuncOp;
460464
}
@@ -470,27 +474,28 @@ struct FuncOpConversion : public FuncOpConversionBase {
470474
LogicalResult
471475
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
472476
ConversionPatternRewriter &rewriter) const override {
473-
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
474-
if (!newFuncOp)
475-
return failure();
477+
FailureOr<LLVM::LLVMFuncOp> newFuncOp =
478+
convertFuncOpToLLVMFuncOp(funcOp, rewriter);
479+
if (failed(newFuncOp))
480+
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
476481

477482
if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
478483
if (funcOp->getAttrOfType<UnitAttr>(
479484
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
480-
if (newFuncOp.isVarArg())
485+
if (newFuncOp->isVarArg())
481486
return funcOp->emitError("C interface for variadic functions is not "
482487
"supported yet.");
483488

484-
if (newFuncOp.isExternal())
485-
wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
486-
funcOp, newFuncOp);
489+
if (newFuncOp->isExternal())
490+
wrapExternalFunction(rewriter, funcOp->getLoc(), *getTypeConverter(),
491+
funcOp, *newFuncOp);
487492
else
488-
wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
489-
funcOp, newFuncOp);
493+
wrapForExternalCallers(rewriter, funcOp->getLoc(),
494+
*getTypeConverter(), funcOp, *newFuncOp);
490495
}
491496
} else {
492-
modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp.getLoc(),
493-
*getTypeConverter(), newFuncOp,
497+
modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp->getLoc(),
498+
*getTypeConverter(), *newFuncOp,
494499
funcOp.getFunctionType().getInputs());
495500
}
496501

mlir/lib/Dialect/Func/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(Extensions)
22
add_subdirectory(IR)
33
add_subdirectory(Transforms)
4+
add_subdirectory(TransformOps)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_dialect_library(MLIRFuncTransformOps
2+
FuncTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/TransformOps
6+
7+
DEPENDS
8+
MLIRFuncTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRFuncDialect
12+
MLIRFuncToLLVM
13+
MLIRIR
14+
MLIRLLVMCommonConversion
15+
MLIRLLVMDialect
16+
MLIRTransformDialect
17+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===- FuncTransformOps.cpp - Implementation of CF transform ops ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
10+
11+
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
17+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
18+
19+
using namespace mlir;
20+
21+
//===----------------------------------------------------------------------===//
22+
// Apply...ConversionPatternsOp
23+
//===----------------------------------------------------------------------===//
24+
25+
void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
26+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
27+
populateFuncToLLVMConversionPatterns(
28+
static_cast<LLVMTypeConverter &>(typeConverter), patterns);
29+
}
30+
31+
LogicalResult
32+
transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
33+
transform::TypeConverterBuilderOpInterface builder) {
34+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
35+
return emitOpError("expected LLVMTypeConverter");
36+
return success();
37+
}
38+
39+
//===----------------------------------------------------------------------===//
40+
// Transform op registration
41+
//===----------------------------------------------------------------------===//
42+
43+
namespace {
44+
class FuncTransformDialectExtension
45+
: public transform::TransformDialectExtension<
46+
FuncTransformDialectExtension> {
47+
public:
48+
using Base::Base;
49+
50+
void init() {
51+
declareGeneratedDialect<LLVM::LLVMDialect>();
52+
53+
registerTransformOps<
54+
#define GET_OP_LIST
55+
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
56+
>();
57+
}
58+
};
59+
} // namespace
60+
61+
#define GET_OP_CLASSES
62+
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
63+
64+
void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
65+
registry.addExtensions<FuncTransformDialectExtension>();
66+
}

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
4444
if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
4545
options.overrideIndexBitwidth(getIndexBitwidth());
4646

47+
// TODO: the following two options don't really make sense for
48+
// memref_to_llvm_type_converter specifically but we should have a single
49+
// to_llvm_type_converter.
50+
if (getDataLayout().has_value())
51+
options.dataLayout = llvm::DataLayout(getDataLayout().value());
52+
options.useBarePtrCallConv = getUseBarePtrCallConv();
53+
4754
return std::make_unique<LLVMTypeConverter>(getContext(), options);
4855
}
4956

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1010

1111
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1213
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1314
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
1415
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
@@ -498,7 +499,7 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
498499
defaultTypeConverter = typeConverterBuilder.getTypeConverter();
499500

500501
// Configure conversion target.
501-
ConversionTarget conversionTarget(*ctx);
502+
ConversionTarget conversionTarget(*getContext());
502503
if (getLegalOps())
503504
for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
504505
conversionTarget.addLegalOp(

mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: mlir-opt -convert-func-to-llvm='use-opaque-pointers=1' -reconcile-unrealized-casts %s | FileCheck %s
2-
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1 use-opaque-pointers=1' -split-input-file %s | FileCheck %s --check-prefix=BAREPTR
2+
3+
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1 use-opaque-pointers=1' %s | FileCheck %s --check-prefix=BAREPTR
4+
5+
// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck %s --check-prefix=BAREPTR
36

47
// These tests were separated from func-memref.mlir because applying
58
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
@@ -32,8 +35,6 @@ func.func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32>
3235
return %static : memref<32x18xf32>
3336
}
3437

35-
// -----
36-
3738
// CHECK-LABEL: func @check_static_return_with_offset
3839
// CHECK-COUNT-2: !llvm.ptr
3940
// CHECK-COUNT-5: i64
@@ -61,7 +62,6 @@ func.func @check_static_return_with_offset(%static : memref<32x18xf32, strided<[
6162
return %static : memref<32x18xf32, strided<[22,1], offset: 7>>
6263
}
6364

64-
// -----
6565

6666
// BAREPTR: llvm.func @foo(!llvm.ptr) -> !llvm.ptr
6767
func.func private @foo(memref<10xi8>) -> memref<20xi8>
@@ -87,24 +87,19 @@ func.func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
8787
return %res : memref<20xi8>
8888
}
8989

90-
// -----
91-
9290
// BAREPTR-LABEL: func @check_return(
9391
// BAREPTR-SAME: %{{.*}}: memref<?xi8>) -> memref<?xi8>
9492
func.func @check_return(%in : memref<?xi8>) -> memref<?xi8> {
9593
// BAREPTR: llvm.return {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
9694
return %in : memref<?xi8>
9795
}
9896

99-
// -----
100-
10197
// BAREPTR-LABEL: func @unconvertible_multiresult
10298
// BAREPTR-SAME: %{{.*}}: memref<?xf32>, %{{.*}}: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
10399
func.func @unconvertible_multiresult(%arg0: memref<?xf32> , %arg1: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
104100
return %arg0, %arg1 : memref<?xf32>, memref<?xf32>
105101
}
106102

107-
// -----
108103
// BAREPTR-LABEL: func @unranked_memref(
109104
// BAREPTR-SAME: %{{.*}}: memref<*xi32>)
110105
func.func @unranked_memref(%arg0:memref<*xi32>) {
@@ -114,3 +109,18 @@ func.func @unranked_memref(%arg0:memref<*xi32>) {
114109
return
115110
}
116111
func.func private @printMemrefI32(memref<*xi32>)
112+
113+
transform.sequence failures(propagate) {
114+
^bb1(%toplevel_module: !transform.any_op):
115+
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
116+
: (!transform.any_op) -> !transform.any_op
117+
transform.apply_conversion_patterns to %func {
118+
transform.apply_conversion_patterns.func.func_to_llvm
119+
} with type_converter {
120+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
121+
{use_bare_ptr_call_conv = true, use_opaque_pointers = true}
122+
} {
123+
legal_dialects = ["llvm"],
124+
partial_conversion
125+
} : !transform.any_op
126+
}

0 commit comments

Comments
 (0)