Skip to content

[mlir][nvvm] Move BasicPtxBuilder Interface to Its Own File (NFC) #68095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -792,10 +792,10 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
//===----------------------------------------------------------------------===//

def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
let summary = "Convert NVVM dialect to LLVM dialect";
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
let description = [{
This pass generates inline assembly for the NVVM ops which is not
implemented in LLVM core.
This pass generates PTX instructions using inline assembly for NVVM
operations implements `BasicPtxBuilderInterface`.
}];
let dependentDialects = [
"NVVM::NVVMDialect",
Expand Down
80 changes: 80 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by NVVM to LLVM pass.
//
//===----------------------------------------------------------------------===//

#ifndef NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
#define NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

namespace mlir {
namespace NVVM {
/// Register read/write modifier to build constraint string for PTX inline
/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
enum class PTXRegisterMod {
/// Read register with no modifier
Read = 0,
/// Read register with '+' modifier
Write = 2,
/// Read register with '=' modifier.
/// Note that, this is not natively supported by LLVM, but it is possible to
/// set read and write for the same operand.
ReadWrite = 1,
};
} // namespace NVVM
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h.inc"

namespace mlir {

namespace NVVM {

/// A class to build PTX assembly automatically. It is used by
/// BasicPtxBuilderInterface.
class PtxBuilder {
// The interface op that is used to build the PTX.
BasicPtxBuilderInterface interfaceOp;
// Rewriter to create new operations.
PatternRewriter &rewriter;
// The operands for the PTX instruction
SmallVector<Value> ptxOperands;
// Register constraints (read, write, readwrite) and register data types
std::string registerConstraints;

bool hasResult = false;

public:
/// Single constructor that only initializes members.
PtxBuilder(Operation *op, PatternRewriter &rewriter)
: interfaceOp(op), rewriter(rewriter) {}

/// Add an operand with the read/write input type.
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);

/// Builds the inline assembly Op and returns it. The `insertValue` needs to
/// be called to pass operands before building the PTX.
LLVM::InlineAsmOp build();

/// Shortcut to build the inline assembly Op and replace or erase the original
/// op with
void buildAndReplaceOp();
};

} // namespace NVVM
} // namespace mlir

#endif // NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
139 changes: 139 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by NVVM to LLVM pass.
//
//===----------------------------------------------------------------------===//

#ifndef BASICPTXBUILDER_OP_INTERFACE
#define BASICPTXBUILDER_OP_INTERFACE

include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"

//===----------------------------------------------------------------------===//
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//

def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
let description = [{
This interface is used to generate inline assembly with PTX for basic
operations. It's utilized in the `convert-nvvm-to-llvm pass` to lower
NVVM Ops that implement this interface to PTX (parallel thread execution)
using inline assembly Ops. Interface methods play a crucial role in this
lowering process.

Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
```tablegen
def NVVM_SpecialOp : NVVM_Op<"special.op",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
...
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
return std::string("special.op %0, %1, %2;");
}
} ];
```

In the above NVVM Op example:
```mlir
%0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
```

The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
The order of arguments is retained, and the read and write modifiers are
set based on the input and result types:
```mlir
%0 = llvm.inline_asm
has_side_effects
asm_dialect =
att "special.op %0, %1, %2;", "=r,l,r" %arg0, %arg1
: (!llvm.ptr, i32) -> i32
```
}];
let cppNamespace = "::mlir::NVVM";
let methods = [
InterfaceMethod<
/*desc=*/[{ Returns PTX assembly with operand number. }],
/*retType=*/"std::string",
/*methodName=*/"getPtx"
>,
InterfaceMethod<
/*desc=*/[{
This function indicates whether the operation is supported by LLVM
intrinsics. It's particularly useful for operations that have
specific cases with LLVM intrinsic support.
}],
/*retType=*/"bool",
/*methodName=*/"hasIntrinsic",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return false;"
>,
InterfaceMethod<
/*desc=*/[{Return whether the operation has memory side effects.}],
/*retType=*/"bool",
/*methodName=*/"hasSideEffect",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return true;"
>,

InterfaceMethod<
/*desc=*/[{Helper function to generate i32 constant value.}],
/*retType=*/"::mlir::Value",
/*methodName=*/"makeConstantI32",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;
return rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getIntegerType(32), val);
}]
>,
InterfaceMethod<
/*desc=*/[{
This function supplies the necessary arguments for passing PTX code,
following this order:
1) Adds results
2) Adds operands
3) Adds attributes
}],
/*retType=*/"void",
/*methodName=*/"getAsmValues",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;

// Step 1. Add results
for (auto val : op->getResults())
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});

// Step 2. Add operands
for (auto val : op->getOperands())
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});

// Step 3. Add attributes
for (auto attr : op->getAttrs()) {
if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
}
}
}]
>
];
}

#endif // BASICPTXBUILDER_OP_INTERFACE
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,18 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve
mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
add_public_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)

set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)

add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
Expand All @@ -26,8 +27,6 @@
namespace mlir {
namespace NVVM {

#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc"

/// NVVM memory space identifiers.
enum NVVMMemorySpace {
/// Global memory space identifier.
Expand Down
Loading