Skip to content

Commit 1ca3376

Browse files
authored
[ThroughMLIR] basic printf support (#1687)
This PR is related to #1685 and adds some basic support for the printf function. Limitations: 1. It only works if all variadic params are of basic integer/float type (for more info on why memref type operands don't work, see #1685). 2. It only works if the format string is defined directly inside the printf call. The downside of this PR is that handling this edge case adds significant code bloat and reduces the readability of the cir.call op lowering (I tried to insert some meaningful comments to improve the readability), but I think it's worth doing so that we have some basic printf support (without adding an extra cir operation) until upstream support for variadic functions is added to the func dialect. Also, a few more tests (which use such a basic form of printf) in the LLVM Single Source test suite are working with this PR: before this PR: Testing Time: 4.00s Total Discovered Tests: 1833 Passed : 420 (22.91%) Failed : 10 (0.55%) Executable Missing: 1403 (76.54%) with this PR: Testing Time: 10.29s Total Discovered Tests: 1833 Passed : 458 (24.99%) Failed : 6 (0.33%) Executable Missing: 1369 (74.69%)
1 parent 3401122 commit 1ca3376

File tree

2 files changed

+206
-28
lines changed

2 files changed

+206
-28
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 168 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2424
#include "mlir/Dialect/Func/IR/FuncOps.h"
2525
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2627
#include "mlir/Dialect/Math/IR/Math.h"
2728
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2829
#include "mlir/Dialect/SCF/IR/SCF.h"
2930
#include "mlir/Dialect/Vector/IR/VectorOps.h"
31+
#include "mlir/IR/BuiltinAttributes.h"
3032
#include "mlir/IR/BuiltinDialect.h"
3133
#include "mlir/IR/BuiltinOps.h"
3234
#include "mlir/IR/BuiltinTypes.h"
@@ -104,9 +106,124 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
104106
if (mlir::failed(
105107
getTypeConverter()->convertTypes(op.getResultTypes(), types)))
106108
return mlir::failure();
107-
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
108-
op, op.getCalleeAttr(), types, adaptor.getOperands());
109-
return mlir::LogicalResult::success();
109+
110+
if (!op.isIndirect()) {
111+
// Currently variadic functions are not supported by the builtin func
112+
// dialect. For now only basic calls to printf are supported by using the
113+
// llvmir dialect.
114+
// TODO: remove this and add support for variadic function calls once
115+
// TODO: supported by the func dialect
116+
if (op.getCallee()->equals_insensitive("printf")) {
117+
SmallVector<mlir::Type> operandTypes =
118+
llvm::to_vector(adaptor.getOperands().getTypes());
119+
120+
// Drop the initial memref operand type (we replace the memref format
121+
// string with equivalent llvm.mlir ops)
122+
operandTypes.erase(operandTypes.begin());
123+
124+
// Check that the printf attributes can be used in llvmir dialect (i.e
125+
// they have integer/float type)
126+
if (!llvm::all_of(operandTypes, [](mlir::Type ty) {
127+
return mlir::LLVM::isCompatibleType(ty);
128+
})) {
129+
return op.emitError()
130+
<< "lowering of printf attributes having a type that is "
131+
"converted to memref in cir-to-mlir lowering (e.g. "
132+
"pointers) not supported yet";
133+
}
134+
135+
// Currently only versions of printf are supported where the format
136+
// string is defined inside the printf ==> the lowering of the cir ops
137+
// will match:
138+
// %global = memref.get_global %frm_str
139+
// %* = memref.reinterpret_cast (%global, 0)
140+
if (auto reinterpret_castOP =
141+
mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
142+
adaptor.getOperands()[0].getDefiningOp())) {
143+
if (auto getGlobalOp =
144+
mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
145+
reinterpret_castOP->getOperand(0).getDefiningOp())) {
146+
mlir::ModuleOp parentModule = op->getParentOfType<mlir::ModuleOp>();
147+
148+
auto context = rewriter.getContext();
149+
150+
// Find the memref.global op defining the frm_str
151+
auto globalOp = parentModule.lookupSymbol<mlir::memref::GlobalOp>(
152+
getGlobalOp.getNameAttr());
153+
154+
rewriter.setInsertionPoint(globalOp);
155+
156+
// Insert an equivalent llvm.mlir.global
157+
auto initialvalueAttr =
158+
mlir::dyn_cast_or_null<mlir::DenseIntElementsAttr>(
159+
globalOp.getInitialValueAttr());
160+
161+
auto type = mlir::LLVM::LLVMArrayType::get(
162+
mlir::IntegerType::get(context, 8),
163+
initialvalueAttr.getNumElements());
164+
165+
auto llvmglobalOp = rewriter.create<mlir::LLVM::GlobalOp>(
166+
globalOp->getLoc(), type, true, mlir::LLVM::Linkage::Internal,
167+
"printf_format_" + globalOp.getSymName().str(),
168+
initialvalueAttr, 0);
169+
170+
rewriter.setInsertionPoint(getGlobalOp);
171+
172+
// Insert llvmir dialect ops to retrieve the !llvm.ptr of the global
173+
auto globalPtrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
174+
getGlobalOp->getLoc(), llvmglobalOp);
175+
176+
mlir::Value cst0 = rewriter.create<mlir::LLVM::ConstantOp>(
177+
getGlobalOp->getLoc(), rewriter.getI8Type(),
178+
rewriter.getIndexAttr(0));
179+
auto gepPtrOp = rewriter.create<mlir::LLVM::GEPOp>(
180+
getGlobalOp->getLoc(),
181+
mlir::LLVM::LLVMPointerType::get(context),
182+
llvmglobalOp.getType(), globalPtrOp,
183+
ArrayRef<mlir::Value>({cst0, cst0}));
184+
185+
mlir::ValueRange operands = adaptor.getOperands();
186+
187+
// Replace the old memref operand with the !llvm.ptr for the frm_str
188+
mlir::SmallVector<mlir::Value> newOperands;
189+
newOperands.push_back(gepPtrOp);
190+
newOperands.append(operands.begin() + 1, operands.end());
191+
192+
// Create the llvmir dialect function type for printf
193+
auto llvmI32Ty = mlir::IntegerType::get(context, 32);
194+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context);
195+
auto llvmFnType =
196+
mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
197+
/*isVarArg=*/true);
198+
199+
rewriter.setInsertionPoint(op);
200+
201+
// Insert an llvm.call op with the updated operands to printf
202+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
203+
op, llvmFnType, op.getCalleeAttr(), newOperands);
204+
205+
// Cleanup printf frm_str memref ops
206+
rewriter.eraseOp(reinterpret_castOP);
207+
rewriter.eraseOp(getGlobalOp);
208+
rewriter.eraseOp(globalOp);
209+
210+
return mlir::LogicalResult::success();
211+
}
212+
}
213+
214+
return op.emitError()
215+
<< "lowering of printf function with Format-String"
216+
"defined outside of printf is not supported yet";
217+
}
218+
219+
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
220+
op, op.getCalleeAttr(), types, adaptor.getOperands());
221+
return mlir::LogicalResult::success();
222+
223+
} else {
224+
// TODO: support lowering of indirect calls via func.call_indirect op
225+
return op.emitError() << "lowering of indirect calls not supported yet";
226+
}
110227
}
111228
};
112229

@@ -557,37 +674,60 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
557674
mlir::ConversionPatternRewriter &rewriter) const override {
558675

559676
auto fnType = op.getFunctionType();
560-
mlir::TypeConverter::SignatureConversion signatureConversion(
561-
fnType.getNumInputs());
562677

563-
for (const auto &argType : enumerate(fnType.getInputs())) {
564-
auto convertedType = typeConverter->convertType(argType.value());
565-
if (!convertedType)
566-
return mlir::failure();
567-
signatureConversion.addInputs(argType.index(), convertedType);
568-
}
678+
if (fnType.isVarArg()) {
679+
// TODO: once the func dialect supports variadic functions rewrite this
680+
// For now only insert special handling of printf via the llvmir dialect
681+
if (op.getSymName().equals_insensitive("printf")) {
682+
auto context = rewriter.getContext();
683+
// Create a llvmir dialect function declaration for printf, the
684+
// signature is: i32 (!llvm.ptr, ...)
685+
auto llvmI32Ty = mlir::IntegerType::get(context, 32);
686+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context);
687+
auto llvmFnType =
688+
mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
689+
/*isVarArg=*/true);
690+
auto printfFunc = rewriter.create<mlir::LLVM::LLVMFuncOp>(
691+
op.getLoc(), "printf", llvmFnType);
692+
rewriter.replaceOp(op, printfFunc);
693+
} else {
694+
rewriter.eraseOp(op);
695+
return op.emitError() << "lowering of variadic functions (except "
696+
"printf) not supported yet";
697+
}
698+
} else {
699+
mlir::TypeConverter::SignatureConversion signatureConversion(
700+
fnType.getNumInputs());
701+
702+
for (const auto &argType : enumerate(fnType.getInputs())) {
703+
auto convertedType = typeConverter->convertType(argType.value());
704+
if (!convertedType)
705+
return mlir::failure();
706+
signatureConversion.addInputs(argType.index(), convertedType);
707+
}
569708

570-
SmallVector<mlir::NamedAttribute, 2> passThroughAttrs;
709+
SmallVector<mlir::NamedAttribute, 2> passThroughAttrs;
571710

572-
if (auto symVisibilityAttr = op.getSymVisibilityAttr())
573-
passThroughAttrs.push_back(
574-
rewriter.getNamedAttr("sym_visibility", symVisibilityAttr));
711+
if (auto symVisibilityAttr = op.getSymVisibilityAttr())
712+
passThroughAttrs.push_back(
713+
rewriter.getNamedAttr("sym_visibility", symVisibilityAttr));
575714

576-
mlir::Type resultType =
577-
getTypeConverter()->convertType(fnType.getReturnType());
578-
auto fn = rewriter.create<mlir::func::FuncOp>(
579-
op.getLoc(), op.getName(),
580-
rewriter.getFunctionType(signatureConversion.getConvertedTypes(),
581-
resultType ? mlir::TypeRange(resultType)
582-
: mlir::TypeRange()),
583-
passThroughAttrs);
715+
mlir::Type resultType =
716+
getTypeConverter()->convertType(fnType.getReturnType());
717+
auto fn = rewriter.create<mlir::func::FuncOp>(
718+
op.getLoc(), op.getName(),
719+
rewriter.getFunctionType(signatureConversion.getConvertedTypes(),
720+
resultType ? mlir::TypeRange(resultType)
721+
: mlir::TypeRange()),
722+
passThroughAttrs);
584723

585-
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
586-
&signatureConversion)))
587-
return mlir::failure();
588-
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
724+
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
725+
&signatureConversion)))
726+
return mlir::failure();
727+
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
589728

590-
rewriter.eraseOp(op);
729+
rewriter.eraseOp(op);
730+
}
591731
return mlir::LogicalResult::success();
592732
}
593733
};

clang/test/CIR/Lowering/ThroughMLIR/call.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,41 @@ int test(void) {
1212
// CHECK: %[[ARG:.+]] = arith.constant 2 : i32
1313
// CHECK-NEXT: call @foo(%[[ARG]]) : (i32) -> ()
1414
// CHECK: }
15+
16+
extern int printf(const char *str, ...);
17+
18+
// CHECK-LABEL: llvm.func @printf(!llvm.ptr, ...) -> i32
19+
// CHECK: llvm.mlir.global internal constant @[[FRMT_STR:.*]](dense<[37, 100, 44, 32, 37, 102, 44, 32, 37, 100, 44, 32, 37, 108, 108, 100, 44, 32, 37, 100, 44, 32, 37, 102, 10, 0]> : tensor<26xi8>) {addr_space = 0 : i32} : !llvm.array<26 x i8>
20+
21+
void testfunc(short s, float X, char C, long long LL, int I, double D) {
22+
printf("%d, %f, %d, %lld, %d, %f\n", s, X, C, LL, I, D);
23+
}
24+
25+
// CHECK: func.func @testfunc(%[[ARG0:.*]]: i16 {{.*}}, %[[ARG1:.*]]: f32 {{.*}}, %[[ARG2:.*]]: i8 {{.*}}, %[[ARG3:.*]]: i64 {{.*}}, %[[ARG4:.*]]: i32 {{.*}}, %[[ARG5:.*]]: f64 {{.*}}) {
26+
// CHECK: %[[ALLOCA_S:.*]] = memref.alloca() {alignment = 2 : i64} : memref<i16>
27+
// CHECK: %[[ALLOCA_X:.*]] = memref.alloca() {alignment = 4 : i64} : memref<f32>
28+
// CHECK: %[[ALLOCA_C:.*]] = memref.alloca() {alignment = 1 : i64} : memref<i8>
29+
// CHECK: %[[ALLOCA_LL:.*]] = memref.alloca() {alignment = 8 : i64} : memref<i64>
30+
// CHECK: %[[ALLOCA_I:.*]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
31+
// CHECK: %[[ALLOCA_D:.*]] = memref.alloca() {alignment = 8 : i64} : memref<f64>
32+
// CHECK: memref.store %[[ARG0]], %[[ALLOCA_S]][] : memref<i16>
33+
// CHECK: memref.store %[[ARG1]], %[[ALLOCA_X]][] : memref<f32>
34+
// CHECK: memref.store %[[ARG2]], %[[ALLOCA_C]][] : memref<i8>
35+
// CHECK: memref.store %[[ARG3]], %[[ALLOCA_LL]][] : memref<i64>
36+
// CHECK: memref.store %[[ARG4]], %[[ALLOCA_I]][] : memref<i32>
37+
// CHECK: memref.store %[[ARG5]], %[[ALLOCA_D]][] : memref<f64>
38+
// CHECK: %[[FRMT_STR_ADDR:.*]] = llvm.mlir.addressof @[[FRMT_STR]] : !llvm.ptr
39+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i8
40+
// CHECK: %[[FRMT_STR_DATA:.*]] = llvm.getelementptr %[[FRMT_STR_ADDR]][%[[C0]], %[[C0]]] : (!llvm.ptr, i8, i8) -> !llvm.ptr, !llvm.array<26 x i8>
41+
// CHECK: %[[S:.*]] = memref.load %[[ALLOCA_S]][] : memref<i16>
42+
// CHECK: %[[S_EXT:.*]] = arith.extsi %3 : i16 to i32
43+
// CHECK: %[[X:.*]] = memref.load %[[ALLOCA_X]][] : memref<f32>
44+
// CHECK: %[[X_EXT:.*]] = arith.extf %5 : f32 to f64
45+
// CHECK: %[[C:.*]] = memref.load %[[ALLOCA_C]][] : memref<i8>
46+
// CHECK: %[[C_EXT:.*]] = arith.extsi %7 : i8 to i32
47+
// CHECK: %[[LL:.*]] = memref.load %[[ALLOCA_LL]][] : memref<i64>
48+
// CHECK: %[[I:.*]] = memref.load %[[ALLOCA_I]][] : memref<i32>
49+
// CHECK: %[[D:.*]] = memref.load %[[ALLOCA_D]][] : memref<f64>
50+
// CHECK: {{.*}} = llvm.call @printf(%[[FRMT_STR_DATA]], %[[S_EXT]], %[[X_EXT]], %[[C_EXT]], %[[LL]], %[[I]], %[[D]]) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32, f64, i32, i64, i32, f64) -> i32
51+
// CHECK: return
52+
// CHECK: }

0 commit comments

Comments
 (0)