[MLIR][NVVM] Add support for tcgen05.{ld, st} #130728

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (202 additions, 0 deletions)
@@ -2929,6 +2929,208 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
}];
}

//===----------------------------------------------------------------------===//
// NVVM tcgen05 LdSt Shape Attr
//===----------------------------------------------------------------------===//

def Tcgen05LdStShape16x64b: I32EnumAttrCase<"SHAPE_16X64B", 0, "shape_16x64b">;
def Tcgen05LdStShape16x128b: I32EnumAttrCase<"SHAPE_16X128B", 1, "shape_16x128b">;
def Tcgen05LdStShape16x256b: I32EnumAttrCase<"SHAPE_16X256B", 2, "shape_16x256b">;
def Tcgen05LdStShape32x32b: I32EnumAttrCase<"SHAPE_32X32B", 3, "shape_32x32b">;
def Tcgen05LdStShape16x32bx2: I32EnumAttrCase<"SHAPE_16X32BX2", 4, "shape_16x32bx2">;

def Tcgen05LdStShape: I32EnumAttr<
"Tcgen05LdStShape",
"",
[Tcgen05LdStShape16x64b, Tcgen05LdStShape16x128b, Tcgen05LdStShape16x256b,
Tcgen05LdStShape32x32b, Tcgen05LdStShape16x32bx2]
> {
let cppNamespace = "::mlir::NVVM";
let genSpecializedAttr = 0;
}

def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst_shape"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
UnitAttr:$pack,
Tcgen05LdStShapeAttr:$shape,
// Arguments
LLVM_PointerTensor:$tmemAddr,
Optional<I64>:$offset
);

let results = (outs AnyTypeOf<[I32, VectorOfLengthAndType<
[2, 4, 8, 16, 32, 64, 128], [I32]>]>:$res);

let assemblyFormat = [{
$tmemAddr (`,` $offset^)? (`pack` $pack^)? attr-dict `:` type($res)
}];

let description = [{
Instruction `tcgen05.ld` asynchronously loads data from the Tensor Memory at
the location specified by the 32-bit address operand `tmemAddr` into the
destination register `res`, collectively across all threads of the warps.

The `shape` and the `num` attributes together determine the total
dimension of the data that is loaded from the Tensor Memory. The `shape`
attribute indicates the base dimension of the data to be accessed, as
described in the Data Movement Shape section of the PTX ISA. The `num`
attribute indicates the repeat factor on the base dimension, resulting in
the total dimension of the data that is accessed. In this Op, `num` is
implied by the element count of the result vector (see the table below).

The shape `16x32bx2` performs two accesses into Tensor Memory of the shape
`16x32b`. The base address of the first access is specified by `tmemAddr`
and the base address of the second access is specified by
`tmemAddr + offset`, where `offset` is an immediate argument.

The unit attribute `pack` can be used to pack two 16-bit
elements from adjacent columns into a single 32-bit element during the load.

The following table gives the length of the result vector for the various
combinations of the `num` and `shape` attributes:

| num/shape | 16x32bx2/16x64b/32x32b | 16x128b | 16x256b |
|-----------|------------------------|---------|---------|
| x1        | 1                      | 2       | 4       |
| x2        | 2                      | 4       | 8       |
| x4        | 4                      | 8       | 16      |
| x8        | 8                      | 16      | 32      |
| x16       | 16                     | 32      | 64      |
| x32       | 32                     | 64      | 128     |
| x64       | 64                     | 128     | NA      |
| x128      | 128                    | NA      | NA      |

Example:
```mlir
%res = nvvm.tcgen05.ld %tmemAddr, %offset pack {
  shape = #nvvm.tcgen05_ldst_shape<shape_16x32bx2>
} : vector<2xi32>
```
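
As a further sketch (not an example from the PTX ISA; the SSA names are
made up), a load of shape `16x128b` into a 4-element result corresponds to
`num = x2` in the table above and needs neither `offset` nor `pack`:
```mlir
%r = nvvm.tcgen05.ld %tmemAddr {
  shape = #nvvm.tcgen05_ldst_shape<shape_16x128b>
} : vector<4xi32>
```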

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
}];

let hasVerifier = 1;

string llvmBuilder = [{
llvm::LLVMContext &Context = moduleTranslation.getLLVMContext();
auto Pack = llvm::ConstantInt::get(Context, llvm::APInt(1, $pack));

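// The `num` qualifier (x1 ... x128) is not a separate attribute; it is
// recovered here from the element count of the result type (1 for a scalar
// i32 result).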
unsigned num = $_resultType->isVectorTy()
? llvm::cast<llvm::VectorType>($_resultType)
->getElementCount()
.getFixedValue()
: 1;

auto ID = getTcgen05LdIntrinsicID($shape, num);
if (ID == llvm::Intrinsic::not_intrinsic)
llvm::report_fatal_error("unknown intrinsic signature for tcgen05.ld");

if ($offset)
$res = createIntrinsicCall(builder, ID, {$tmemAddr, $offset, Pack});
else
$res = createIntrinsicCall(builder, ID, {$tmemAddr, Pack});
}];
}

//===----------------------------------------------------------------------===//
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
UnitAttr:$unpack,
Tcgen05LdStShapeAttr:$shape,
// Arguments
LLVM_PointerTensor:$tmemAddr,
AnyTypeOf<[I32, VectorOfLengthAndType<
[2, 4, 8, 16, 32, 64, 128], [I32]>]>:$val,
Optional<I64>:$offset
);

let assemblyFormat = [{
$tmemAddr `,` $val (`,` $offset^)? (`unpack` $unpack^)? attr-dict `:` type($val)
}];

let description = [{
Instruction `tcgen05.st` asynchronously stores data from the source operand `val`
into the Tensor Memory at the location specified by the 32-bit address operand
`tmemAddr`, collectively across all threads of the warps.

The `shape` and the `num` attributes together determine the total dimension
of the data that is stored to the Tensor Memory. The `shape` attribute
indicates the base dimension of the data to be accessed. The `num` attribute
indicates the repeat factor on the base dimension, resulting in the total
dimension of the data that is accessed. In this Op, `num` is implied by the
element count of `val` (see the table below).

The shape `16x32bx2` performs two accesses into Tensor Memory of the shape
`16x32b`. The base address of the first access is specified by `tmemAddr`
and the base address of the second access is specified by
`tmemAddr + offset`, where `offset` is an immediate argument.

The unit attribute `unpack` can be used to unpack a 32-bit element
in the register into two 16-bit elements and store them in adjacent columns.

The following table gives the length of the `val` vector for the various
combinations of the `num` and `shape` attributes:

| num/shape | 16x32bx2/16x64b/32x32b | 16x128b | 16x256b |
|-----------|------------------------|---------|---------|
| x1        | 1                      | 2       | 4       |
| x2        | 2                      | 4       | 8       |
| x4        | 4                      | 8       | 16      |
| x8        | 8                      | 16      | 32      |
| x16       | 16                     | 32      | 64      |
| x32       | 32                     | 64      | 128     |
| x64       | 64                     | 128     | NA      |
| x128      | 128                    | NA      | NA      |

Example:
```mlir
nvvm.tcgen05.st %tmemAddr, %val, %offset unpack {
  shape = #nvvm.tcgen05_ldst_shape<shape_16x32bx2>
} : vector<2xi32>
```
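
As a further sketch (again with made-up SSA names), storing a single `i32`
with shape `32x32b` corresponds to `num = x1` in the table above and omits
both `offset` and `unpack`:
```mlir
nvvm.tcgen05.st %tmemAddr, %scalarVal {
  shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>
} : i32
```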

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
}];

string llvmBuilder = [{
llvm::LLVMContext &Context = moduleTranslation.getLLVMContext();
auto Unpack = llvm::ConstantInt::get(Context, llvm::APInt(1, $unpack));

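// As for tcgen05.ld, the `num` qualifier is recovered from the element count
// of `val` (1 for a scalar i32).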
auto valTy = $val->getType();
uint32_t num = valTy->isVectorTy() ? llvm::cast<llvm::VectorType>(valTy)
->getElementCount()
.getFixedValue()
: 1;

auto ID = getTcgen05StIntrinsicID($shape, num);
if (ID == llvm::Intrinsic::not_intrinsic)
llvm::report_fatal_error("unknown intrinsic signature for tcgen05.st");

if ($offset)
createIntrinsicCall(builder, ID, {$tmemAddr, $offset, $val, Unpack});
else
createIntrinsicCall(builder, ID, {$tmemAddr, $val, Unpack});
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (46 additions, 0 deletions)
@@ -35,6 +35,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
@@ -1387,6 +1388,51 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
llvm_unreachable("Invalid shape in tcgen05 cp Op");
}

// Returns true when the given vector length is valid for the given shape; this
// models the table in the tcgen05.{ld, st} Op descriptions.
static bool isValidVectorLength(NVVM::Tcgen05LdStShape Shape,
unsigned VecLen) {
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
return VecLen >= 2;
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
return VecLen >= 4;
return true;
}

LogicalResult Tcgen05LdOp::verify() {
LogicalResult Result = success();
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
Result = emitError("shape 16x32bx2 requires offset argument");

auto ResTy = getRes().getType();
unsigned ResLen = isa<VectorType>(ResTy)
? llvm::cast<VectorType>(ResTy).getNumElements()
: 1;
if (!isValidVectorLength(getShape(), ResLen))
Result = emitError(llvm::formatv("invalid result type length {0} for shape "
"{1} in tcgen05.ld Op",
ResLen, stringifyEnum(getShape())));

return Result;
}

LogicalResult Tcgen05StOp::verify() {
LogicalResult Result = success();
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
Result = emitError("shape 16x32bx2 requires offset argument");

auto ValTy = getVal().getType();
unsigned ValLen = isa<VectorType>(ValTy)
? llvm::cast<VectorType>(ValTy).getNumElements()
: 1;
if (!isValidVectorLength(getShape(), ValLen))
Result = emitError(llvm::formatv("invalid input length {0} for shape "
"{1} in tcgen05.st Op",
ValLen, stringifyEnum(getShape())));

return Result;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (106 additions, 0 deletions)
@@ -170,6 +170,112 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
llvm_unreachable("Unsupported proxy kinds");
}

#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM

static llvm::Intrinsic::ID
getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
llvm::Intrinsic::ID Shape16x64b[] = {
TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4),
TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
};

llvm::Intrinsic::ID Shape16x128b[] = {
TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4),
TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
TCGEN05LD(16x128b, x64),
};

llvm::Intrinsic::ID Shape16x256b[] = {
TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4),
TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
};

llvm::Intrinsic::ID Shape16x32bx2[] = {
TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2),
TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8),
TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
};

llvm::Intrinsic::ID Shape32x32b[] = {
TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4),
TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
};

// `num` is the vector length (1 for a scalar), and log2(num) gives the index
// into the shape array. For 16x128b and 16x256b the smallest valid vector
// lengths are 2 and 4, hence the index adjustments of 1 and 2 below.
unsigned Idx = std::log2(num);

switch (shape) {
case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
return Shape16x64b[Idx];
case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
return Shape16x128b[Idx - 1];
case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
return Shape16x256b[Idx - 2];
case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
return Shape32x32b[Idx];
case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
return Shape16x32bx2[Idx];
}
llvm_unreachable("unhandled tcgen05.ld lowering");
}

#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM

static llvm::Intrinsic::ID
getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
llvm::Intrinsic::ID Shape16x64b[] = {
TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4),
TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
};

llvm::Intrinsic::ID Shape16x128b[] = {
TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4),
TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
TCGEN05ST(16x128b, x64),
};

llvm::Intrinsic::ID Shape16x256b[] = {
TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4),
TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
};

llvm::Intrinsic::ID Shape16x32bx2[] = {
TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2),
TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8),
TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
};

llvm::Intrinsic::ID Shape32x32b[] = {
TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4),
TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
};

// `num` is the vector length (1 for a scalar), and log2(num) gives the index
// into the shape array. For 16x128b and 16x256b the smallest valid vector
// lengths are 2 and 4, hence the index adjustments of 1 and 2 below.
unsigned Idx = std::log2(num);

switch (shape) {
case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
return Shape16x64b[Idx];
case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
return Shape16x128b[Idx - 1];
case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
return Shape16x256b[Idx - 2];
case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
return Shape32x32b[Idx];
case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
return Shape16x32bx2[Idx];
}
llvm_unreachable("unhandled tcgen05.st lowering");
}

namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.