|
23 | 23 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
24 | 24 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
25 | 25 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
| 26 | +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
26 | 27 | #include "mlir/Dialect/Math/IR/Math.h"
|
27 | 28 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
28 | 29 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
29 | 30 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
| 31 | +#include "mlir/IR/BuiltinAttributes.h" |
30 | 32 | #include "mlir/IR/BuiltinDialect.h"
|
31 | 33 | #include "mlir/IR/BuiltinOps.h"
|
32 | 34 | #include "mlir/IR/BuiltinTypes.h"
|
@@ -104,9 +106,124 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
|
104 | 106 | if (mlir::failed(
|
105 | 107 | getTypeConverter()->convertTypes(op.getResultTypes(), types)))
|
106 | 108 | return mlir::failure();
|
107 |
| - rewriter.replaceOpWithNewOp<mlir::func::CallOp>( |
108 |
| - op, op.getCalleeAttr(), types, adaptor.getOperands()); |
109 |
| - return mlir::LogicalResult::success(); |
| 109 | + |
| 110 | + if (!op.isIndirect()) { |
| 111 | + // Currently variadic functions are not supported by the builtin func |
| 112 | + // dialect. For now only basic call to printf are supported by using the |
| 113 | + // llvmir dialect. |
| 114 | + // TODO: remove this and add support for variadic function calls once |
| 115 | + // TODO: supported by the func dialect |
| 116 | + if (op.getCallee()->equals_insensitive("printf")) { |
| 117 | + SmallVector<mlir::Type> operandTypes = |
| 118 | + llvm::to_vector(adaptor.getOperands().getTypes()); |
| 119 | + |
| 120 | + // Drop the initial memref operand type (we replace the memref format |
| 121 | + // string with equivalent llvm.mlir ops) |
| 122 | + operandTypes.erase(operandTypes.begin()); |
| 123 | + |
| 124 | + // Check that the printf attributes can be used in llvmir dialect (i.e |
| 125 | + // they have integer/float type) |
| 126 | + if (!llvm::all_of(operandTypes, [](mlir::Type ty) { |
| 127 | + return mlir::LLVM::isCompatibleType(ty); |
| 128 | + })) { |
| 129 | + return op.emitError() |
| 130 | + << "lowering of printf attributes having a type that is " |
| 131 | + "converted to memref in cir-to-mlir lowering (e.g. " |
| 132 | + "pointers) not supported yet"; |
| 133 | + } |
| 134 | + |
| 135 | + // Currently only versions of printf are supported where the format |
| 136 | + // string is defined inside the printf ==> the lowering of the cir ops |
| 137 | + // will match: |
| 138 | + // %global = memref.get_global %frm_str |
| 139 | + // %* = memref.reinterpret_cast (%global, 0) |
| 140 | + if (auto reinterpret_castOP = |
| 141 | + mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>( |
| 142 | + adaptor.getOperands()[0].getDefiningOp())) { |
| 143 | + if (auto getGlobalOp = |
| 144 | + mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>( |
| 145 | + reinterpret_castOP->getOperand(0).getDefiningOp())) { |
| 146 | + mlir::ModuleOp parentModule = op->getParentOfType<mlir::ModuleOp>(); |
| 147 | + |
| 148 | + auto context = rewriter.getContext(); |
| 149 | + |
| 150 | + // Find the memref.global op defining the frm_str |
| 151 | + auto globalOp = parentModule.lookupSymbol<mlir::memref::GlobalOp>( |
| 152 | + getGlobalOp.getNameAttr()); |
| 153 | + |
| 154 | + rewriter.setInsertionPoint(globalOp); |
| 155 | + |
| 156 | + // Insert a equivalent llvm.mlir.global |
| 157 | + auto initialvalueAttr = |
| 158 | + mlir::dyn_cast_or_null<mlir::DenseIntElementsAttr>( |
| 159 | + globalOp.getInitialValueAttr()); |
| 160 | + |
| 161 | + auto type = mlir::LLVM::LLVMArrayType::get( |
| 162 | + mlir::IntegerType::get(context, 8), |
| 163 | + initialvalueAttr.getNumElements()); |
| 164 | + |
| 165 | + auto llvmglobalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
| 166 | + globalOp->getLoc(), type, true, mlir::LLVM::Linkage::Internal, |
| 167 | + "printf_format_" + globalOp.getSymName().str(), |
| 168 | + initialvalueAttr, 0); |
| 169 | + |
| 170 | + rewriter.setInsertionPoint(getGlobalOp); |
| 171 | + |
| 172 | + // Insert llvmir dialect ops to retrive the !llvm.ptr of the global |
| 173 | + auto globalPtrOp = rewriter.create<mlir::LLVM::AddressOfOp>( |
| 174 | + getGlobalOp->getLoc(), llvmglobalOp); |
| 175 | + |
| 176 | + mlir::Value cst0 = rewriter.create<mlir::LLVM::ConstantOp>( |
| 177 | + getGlobalOp->getLoc(), rewriter.getI8Type(), |
| 178 | + rewriter.getIndexAttr(0)); |
| 179 | + auto gepPtrOp = rewriter.create<mlir::LLVM::GEPOp>( |
| 180 | + getGlobalOp->getLoc(), |
| 181 | + mlir::LLVM::LLVMPointerType::get(context), |
| 182 | + llvmglobalOp.getType(), globalPtrOp, |
| 183 | + ArrayRef<mlir::Value>({cst0, cst0})); |
| 184 | + |
| 185 | + mlir::ValueRange operands = adaptor.getOperands(); |
| 186 | + |
| 187 | + // Replace the old memref operand with the !llvm.ptr for the frm_str |
| 188 | + mlir::SmallVector<mlir::Value> newOperands; |
| 189 | + newOperands.push_back(gepPtrOp); |
| 190 | + newOperands.append(operands.begin() + 1, operands.end()); |
| 191 | + |
| 192 | + // Create the llvmir dialect function type for printf |
| 193 | + auto llvmI32Ty = mlir::IntegerType::get(context, 32); |
| 194 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); |
| 195 | + auto llvmFnType = |
| 196 | + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, |
| 197 | + /*isVarArg=*/true); |
| 198 | + |
| 199 | + rewriter.setInsertionPoint(op); |
| 200 | + |
| 201 | + // Insert an llvm.call op with the updated operands to printf |
| 202 | + rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| 203 | + op, llvmFnType, op.getCalleeAttr(), newOperands); |
| 204 | + |
| 205 | + // Cleanup printf frm_str memref ops |
| 206 | + rewriter.eraseOp(reinterpret_castOP); |
| 207 | + rewriter.eraseOp(getGlobalOp); |
| 208 | + rewriter.eraseOp(globalOp); |
| 209 | + |
| 210 | + return mlir::LogicalResult::success(); |
| 211 | + } |
| 212 | + } |
| 213 | + |
| 214 | + return op.emitError() |
| 215 | + << "lowering of printf function with Format-String" |
| 216 | + "defined outside of printf is not supported yet"; |
| 217 | + } |
| 218 | + |
| 219 | + rewriter.replaceOpWithNewOp<mlir::func::CallOp>( |
| 220 | + op, op.getCalleeAttr(), types, adaptor.getOperands()); |
| 221 | + return mlir::LogicalResult::success(); |
| 222 | + |
| 223 | + } else { |
| 224 | + // TODO: support lowering of indirect calls via func.call_indirect op |
| 225 | + return op.emitError() << "lowering of indirect calls not supported yet"; |
| 226 | + } |
110 | 227 | }
|
111 | 228 | };
|
112 | 229 |
|
@@ -557,37 +674,60 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
|
557 | 674 | mlir::ConversionPatternRewriter &rewriter) const override {
|
558 | 675 |
|
559 | 676 | auto fnType = op.getFunctionType();
|
560 |
| - mlir::TypeConverter::SignatureConversion signatureConversion( |
561 |
| - fnType.getNumInputs()); |
562 | 677 |
|
563 |
| - for (const auto &argType : enumerate(fnType.getInputs())) { |
564 |
| - auto convertedType = typeConverter->convertType(argType.value()); |
565 |
| - if (!convertedType) |
566 |
| - return mlir::failure(); |
567 |
| - signatureConversion.addInputs(argType.index(), convertedType); |
568 |
| - } |
| 678 | + if (fnType.isVarArg()) { |
| 679 | + // TODO: once the func dialect supports variadic functions rewrite this |
| 680 | + // For now only insert special handling of printf via the llvmir dialect |
| 681 | + if (op.getSymName().equals_insensitive("printf")) { |
| 682 | + auto context = rewriter.getContext(); |
| 683 | + // Create a llvmir dialect function declaration for printf, the |
| 684 | + // signature is: i32 (!llvm.ptr, ...) |
| 685 | + auto llvmI32Ty = mlir::IntegerType::get(context, 32); |
| 686 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); |
| 687 | + auto llvmFnType = |
| 688 | + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, |
| 689 | + /*isVarArg=*/true); |
| 690 | + auto printfFunc = rewriter.create<mlir::LLVM::LLVMFuncOp>( |
| 691 | + op.getLoc(), "printf", llvmFnType); |
| 692 | + rewriter.replaceOp(op, printfFunc); |
| 693 | + } else { |
| 694 | + rewriter.eraseOp(op); |
| 695 | + return op.emitError() << "lowering of variadic functions (except " |
| 696 | + "printf) not supported yet"; |
| 697 | + } |
| 698 | + } else { |
| 699 | + mlir::TypeConverter::SignatureConversion signatureConversion( |
| 700 | + fnType.getNumInputs()); |
| 701 | + |
| 702 | + for (const auto &argType : enumerate(fnType.getInputs())) { |
| 703 | + auto convertedType = typeConverter->convertType(argType.value()); |
| 704 | + if (!convertedType) |
| 705 | + return mlir::failure(); |
| 706 | + signatureConversion.addInputs(argType.index(), convertedType); |
| 707 | + } |
569 | 708 |
|
570 |
| - SmallVector<mlir::NamedAttribute, 2> passThroughAttrs; |
| 709 | + SmallVector<mlir::NamedAttribute, 2> passThroughAttrs; |
571 | 710 |
|
572 |
| - if (auto symVisibilityAttr = op.getSymVisibilityAttr()) |
573 |
| - passThroughAttrs.push_back( |
574 |
| - rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); |
| 711 | + if (auto symVisibilityAttr = op.getSymVisibilityAttr()) |
| 712 | + passThroughAttrs.push_back( |
| 713 | + rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); |
575 | 714 |
|
576 |
| - mlir::Type resultType = |
577 |
| - getTypeConverter()->convertType(fnType.getReturnType()); |
578 |
| - auto fn = rewriter.create<mlir::func::FuncOp>( |
579 |
| - op.getLoc(), op.getName(), |
580 |
| - rewriter.getFunctionType(signatureConversion.getConvertedTypes(), |
581 |
| - resultType ? mlir::TypeRange(resultType) |
582 |
| - : mlir::TypeRange()), |
583 |
| - passThroughAttrs); |
| 715 | + mlir::Type resultType = |
| 716 | + getTypeConverter()->convertType(fnType.getReturnType()); |
| 717 | + auto fn = rewriter.create<mlir::func::FuncOp>( |
| 718 | + op.getLoc(), op.getName(), |
| 719 | + rewriter.getFunctionType(signatureConversion.getConvertedTypes(), |
| 720 | + resultType ? mlir::TypeRange(resultType) |
| 721 | + : mlir::TypeRange()), |
| 722 | + passThroughAttrs); |
584 | 723 |
|
585 |
| - if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, |
586 |
| - &signatureConversion))) |
587 |
| - return mlir::failure(); |
588 |
| - rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
| 724 | + if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, |
| 725 | + &signatureConversion))) |
| 726 | + return mlir::failure(); |
| 727 | + rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
589 | 728 |
|
590 |
| - rewriter.eraseOp(op); |
| 729 | + rewriter.eraseOp(op); |
| 730 | + } |
591 | 731 | return mlir::LogicalResult::success();
|
592 | 732 | }
|
593 | 733 | };
|
|
0 commit comments