Skip to content

[mlir][LLVM] Add operand bundle support #108933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ static unsigned getLenParamFieldId(mlir::Type ty) {
return getTypeDescFieldId(ty) + 1;
}

static llvm::SmallVector<mlir::NamedAttribute>
addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
llvm::ArrayRef<mlir::NamedAttribute> attrs,
int32_t numCallOperands) {
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
newAttrs.reserve(attrs.size() + 2);

for (mlir::NamedAttribute attr : attrs) {
if (attr.getName() != "operandSegmentSizes")
newAttrs.push_back(attr);
}

newAttrs.push_back(rewriter.getNamedAttr(
"operandSegmentSizes",
rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
rewriter.getDenseI32ArrayAttr({})));
return newAttrs;
}

namespace {
/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
Expand Down Expand Up @@ -229,7 +249,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
mlir::NamedAttribute attr = rewriter.getNamedAttr(
"callee", mlir::SymbolRefAttr::get(memSizeFn));
auto call = rewriter.create<mlir::LLVM::CallOp>(
loc, ity, lenParams, llvm::ArrayRef<mlir::NamedAttribute>{attr});
loc, ity, lenParams,
addLLVMOpBundleAttrs(rewriter, {attr}, lenParams.size()));
size = call.getResult();
llvmObjectType = ::getI8Type(alloc.getContext());
} else {
Expand Down Expand Up @@ -559,7 +580,9 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
call, resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
return mlir::success();
}
};
Expand Down Expand Up @@ -980,7 +1003,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
heap->setAttr("callee", getMalloc(heap, rewriter));
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
heap, ::getLlvmPtrType(heap.getContext()), size,
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1));
return mlir::success();
}

Expand Down Expand Up @@ -1037,9 +1061,9 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = freemem.getLoc();
freemem->setAttr("callee", getFree(freemem, rewriter));
rewriter.create<mlir::LLVM::CallOp>(loc, mlir::TypeRange{},
mlir::ValueRange{adaptor.getHeapref()},
freemem->getAttrs());
rewriter.create<mlir::LLVM::CallOp>(
loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()},
addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 1));
rewriter.eraseOp(freemem);
return mlir::success();
}
Expand Down Expand Up @@ -2671,7 +2695,8 @@ struct FieldIndexOpConversion : public fir::FIROpConversion<fir::FieldIndexOp> {
"field", mlir::IntegerAttr::get(lowerTy().indexType(), index));
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
field, lowerTy().offsetType(), adaptor.getOperands(),
llvm::ArrayRef<mlir::NamedAttribute>{callAttr, fieldAttr});
addLLVMOpBundleAttrs(rewriter, {callAttr, fieldAttr},
adaptor.getOperands().size()));
return mlir::success();
}

Expand Down
44 changes: 37 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
DefaultValuedProperty<
ArrayProperty<StringProperty, "operand bundle tags">,
"ArrayRef<std::string>{}",
"SmallVector<std::string>{}"
>:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$result);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
Expand Down Expand Up @@ -607,7 +615,8 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
//===----------------------------------------------------------------------===//

def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
Expand Down Expand Up @@ -661,8 +670,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
OptionalAttr<UnitAttr>:$convergent,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return
);
OptionalAttr<UnitAttr>:$will_return,
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
DefaultValuedProperty<
ArrayProperty<StringProperty, "operand bundle tags">,
"ArrayRef<std::string>{}",
"SmallVector<std::string>{}"
>:$op_bundle_tags);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
Expand All @@ -682,6 +698,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Returns the callee function type.
Expand Down Expand Up @@ -1895,21 +1912,34 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf

def LLVM_CallIntrinsicOp
: LLVM_Op<"call_intrinsic",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
the MLIR function type of this op to determine which intrinsic to call.
}];
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags);
"{}">:$fastmathFlags,
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
DefaultValuedProperty<
ArrayProperty<StringProperty, "operand bundle tags">,
"ArrayRef<std::string>{}",
"SmallVector<std::string>{}"
>:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
$intrin `(` $args `)`
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
$op_bundle_tags)^ )?
`:` functional-type($args, $results)
attr-dict
}];

let hasVerifier = 1;
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());

newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});

SmallVector<Value, 4> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,17 +837,23 @@ class FunctionCallPattern
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (callOp.getNumResults() == 0) {
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(adaptor.getOperands().size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
return success();
}

// Function returns a single result.
auto dstType = typeConverter.convertType(callOp.getType(0));
if (!dstType)
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(adaptor.getOperands().size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
return success();
}
};
Expand Down
Loading
Loading