
[MLIR][NVVM] Run clang-tidy #135006


Merged (2 commits) on Apr 10, 2025
109 changes: 58 additions & 51 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
Contributor

I see this as an addition-only change.

For my understanding, could you please let me know why this header is added as part of this change?

Member Author

Ah, because there was an `auto` usage without a clear type. I used the real type name, so the intrinsic header needed to be added.
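
For illustration of the reply above, here is a minimal, hypothetical sketch of the pattern rather than a hunk from this PR (the intrinsic `nvvm_f2tf32_rna_satfinite` is used only as an example): once `auto` is replaced by the spelled-out `llvm::Intrinsic::ID`, the file's use of the NVPTX intrinsic enumerators is explicit, which is what the new `llvm/IR/IntrinsicsNVPTX.h` include covers.

```cpp
// Hypothetical sketch, not a hunk from this PR.
#include "llvm/IR/Intrinsics.h"      // declares llvm::Intrinsic::ID
#include "llvm/IR/IntrinsicsNVPTX.h" // declares the nvvm_* enumerators

static llvm::Intrinsic::ID pickConversionIntrinsic() {
  // Before: `auto id = llvm::Intrinsic::nvvm_f2tf32_rna_satfinite;` hid the type.
  // After: the concrete type is written out, as clang-tidy suggests, and the
  // header that declares the NVPTX enumerators is included explicitly.
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_f2tf32_rna_satfinite;
  return id;
}
```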

#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
@@ -56,7 +57,7 @@ using namespace NVVM;
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
size_t numIm2ColOffsets,
Location loc) {
@@ -81,7 +82,7 @@ static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
}

@@ -105,13 +106,13 @@ LogicalResult CpAsyncOp::verify() {
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
getLoc());
}

@@ -183,14 +184,14 @@ static bool isIntegerPtxType(MMATypes type) {

MMATypes MmaOp::accumPtxType() {
std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
getODSOperands(2).getTypes().front(), /*isAccum=*/true);
getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
assert(val.has_value() && "accumulator PTX type should always be inferrable");
return val.value();
}

MMATypes MmaOp::resultPtxType() {
std::optional<mlir::NVVM::MMATypes> val =
inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
assert(val.has_value() && "result PTX type should always be inferrable");
return val.value();
}
@@ -224,7 +225,7 @@ void MmaOp::print(OpAsmPrinter &p) {
}
}
std::optional<MMATypes> inferredType =
inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
@@ -243,7 +244,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

// Print the types of the operands and result.
p << " : " << "(";
p << " : "
<< "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -363,14 +365,14 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
parser.getNameLoc(), result.operands)))
return failure();
frag.elemtype =
inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
frag.elemtype = inferOperandMMAType(frag.regTypes[0],
/*isAccumulator*/ iter.index() < 2);
}

Type resultType;
if (parser.parseArrow() || parser.parseType(resultType))
return failure();
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);

std::array<StringRef, 2> names{"multiplicandAPtxType",
"multiplicandBPtxType"};
@@ -992,7 +994,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
ss << " $" << (regCnt) << ","
<< " $" << (regCnt + 1) << ","
<< " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -1118,9 +1122,9 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {

LogicalResult NVVM::MatchSyncOp::verify() {
if (getKind() == NVVM::MatchSyncKind::all) {
auto Type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
if (!Type || Type.getBody().size() != 2 ||
!Type.getBody()[0].isInteger(32) || !Type.getBody()[1].isInteger(1)) {
auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
if (!type || type.getBody().size() != 2 ||
!type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
return emitOpError("match.sync 'all' returns a two element struct with "
"first element as i32 and second element as i1");
}
@@ -1161,7 +1165,7 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::Intrinsic::ID id;

auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
bool hasCpSize = cpAsyncOp.getCpSize() ? true : false;
bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
switch (cpAsyncOp.getSize()) {
case 4:
id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
@@ -1260,6 +1264,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

#define _none
Contributor

Is this one also auto-fixed by clang-tidy?

I wonder if we should name it something clearer, like macro_null_str (or something similar), to convey that it is used as an empty string in macros.

Member Author

It's not auto-fixed, but clang-tidy complained about it.
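
For context, here is a minimal sketch of the macro pattern under discussion, with hypothetical names rather than the PR's real macros: passing a literally empty argument, as the old `GET_CVT_F2TF32_ID(rna, , _satfinite)` call did, is easy to misread and is the kind of thing clang-tidy complained about, while passing a macro that expands to nothing documents the intent and still pastes the same tokens.

```cpp
// Hypothetical sketch; the real macros live in NVVMDialect.cpp.
#include <cstdio>

#define EMPTY                               // expands to nothing, like _none above
#define PASTE_IMPL(a, b, c) a##b##c         // inner macro performs the token pasting
#define PASTE(a, b, c) PASTE_IMPL(a, b, c)  // outer macro expands its arguments first

int main() {
  // Old style: PASTE(foo, , _bar) is valid, but the empty argument is easy to
  // miss at the call site. With a named empty macro, EMPTY expands to nothing
  // before reaching PASTE_IMPL, so the pasted result is still `foo_bar`.
  int PASTE(foo, EMPTY, _bar) = 42;  // declares and initializes foo_bar
  std::printf("%d\n", foo_bar);
  return 0;
}
```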


#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
@@ -1279,7 +1285,7 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
case RndMode::RZ:
return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
case RndMode::RNA:
return GET_CVT_F2TF32_ID(rna, , _satfinite);
return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
default:
llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}
@@ -1290,9 +1296,9 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
llvm::SmallVector<llvm::Value *> &args) {
auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
.getAddressSpace();
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

llvm::Intrinsic::ID id;
@@ -1339,14 +1345,15 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
llvm::SmallVector<llvm::Value *> &args) {
auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
.getAddressSpace();
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
bool hasMulticast = curOp.getMulticastMask() ? true : false;
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
: GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
llvm::Intrinsic::ID id =
is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
: GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

// Fill the Intrinsic Args
args.push_back(mt.lookupValue(curOp.getAddr()));
@@ -1365,9 +1372,9 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
[&]() -> auto { \
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
@@ -1397,47 +1404,47 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {

// Returns the valid vector length for a given shape and vector length, the
// function models the table mentioned in the tcgen05.{ld, st} Op description
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape Shape,
unsigned VecLen) {
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
return VecLen >= 2;
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
return VecLen >= 4;
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
unsigned vecLen) {
if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
return vecLen >= 2;
if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
return vecLen >= 4;
return true;
}

LogicalResult Tcgen05LdOp::verify() {
LogicalResult Result = success();
LogicalResult result = success();
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
Result = emitError("shape 16x32bx2 requires offset argument");
result = emitError("shape 16x32bx2 requires offset argument");

auto ResTy = getRes().getType();
unsigned ResLen = isa<VectorType>(ResTy)
? llvm::cast<VectorType>(ResTy).getNumElements()
auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
: 1;
if (!isValidVectorLength(getShape(), ResLen))
Result = emitError(llvm::formatv("invalid result type length {0} for shape "
if (!isValidVectorLength(getShape(), resLen))
result = emitError(llvm::formatv("invalid result type length {0} for shape "
"{1} in tcgen05.ld Op",
ResLen, stringifyEnum(getShape())));
resLen, stringifyEnum(getShape())));

return Result;
return result;
}

LogicalResult Tcgen05StOp::verify() {
LogicalResult Result = success();
LogicalResult result = success();
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
Result = emitError("shape 16x32bx2 requires offset argument");
result = emitError("shape 16x32bx2 requires offset argument");

auto ValTy = getVal().getType();
unsigned ValLen = isa<VectorType>(ValTy)
? llvm::cast<VectorType>(ValTy).getNumElements()
auto valTy = getVal().getType();
unsigned valLen = isa<VectorType>(valTy)
? llvm::cast<VectorType>(valTy).getNumElements()
: 1;
if (!isValidVectorLength(getShape(), ValLen))
Result = emitError(llvm::formatv("invalid input length {0} for shape "
if (!isValidVectorLength(getShape(), valLen))
result = emitError(llvm::formatv("invalid input length {0} for shape "
"{1} in tcgen05.st Op",
ValLen, stringifyEnum(getShape())));
valLen, stringifyEnum(getShape())));

return Result;
return result;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
@@ -1557,7 +1564,7 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return failure();
}
if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
return attr && mlir::isa<StringAttr>(attr);
return mlir::isa_and_nonnull<StringAttr>(attr);
})) {
emitError() << "All the elements in the `link` array must be strings.";
return failure();