Skip to content

Commit d7cb24e

Browse files
authored
[MLIR][NVVM] Run clang-tidy (#135006)
1 parent d07a216 commit d7cb24e

File tree

1 file changed

+53
-49
lines changed

1 file changed

+53
-49
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/AsmParser/Parser.h"
3434
#include "llvm/IR/Attributes.h"
3535
#include "llvm/IR/Function.h"
36+
#include "llvm/IR/IntrinsicsNVPTX.h"
3637
#include "llvm/IR/Type.h"
3738
#include "llvm/Support/Casting.h"
3839
#include "llvm/Support/FormatVariadic.h"
@@ -56,7 +57,7 @@ using namespace NVVM;
5657
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
5758
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
5859
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
59-
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
60+
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
6061
bool isIm2Col,
6162
size_t numIm2ColOffsets,
6263
Location loc) {
@@ -81,7 +82,7 @@ static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
8182
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
8283
size_t numIm2ColOffsets = getIm2colOffsets().size();
8384
bool isIm2Col = numIm2ColOffsets > 0;
84-
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
85+
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
8586
numIm2ColOffsets, getLoc());
8687
}
8788

@@ -105,13 +106,13 @@ LogicalResult CpAsyncOp::verify() {
105106
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
106107
size_t numIm2ColOffsets = getIm2colOffsets().size();
107108
bool isIm2Col = numIm2ColOffsets > 0;
108-
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
109+
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
109110
numIm2ColOffsets, getLoc());
110111
}
111112

112113
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
113114
bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
114-
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
115+
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
115116
getLoc());
116117
}
117118

@@ -183,14 +184,14 @@ static bool isIntegerPtxType(MMATypes type) {
183184

184185
MMATypes MmaOp::accumPtxType() {
185186
std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
186-
getODSOperands(2).getTypes().front(), /*isAccum=*/true);
187+
getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
187188
assert(val.has_value() && "accumulator PTX type should always be inferrable");
188189
return val.value();
189190
}
190191

191192
MMATypes MmaOp::resultPtxType() {
192193
std::optional<mlir::NVVM::MMATypes> val =
193-
inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
194+
inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
194195
assert(val.has_value() && "result PTX type should always be inferrable");
195196
return val.value();
196197
}
@@ -224,7 +225,7 @@ void MmaOp::print(OpAsmPrinter &p) {
224225
}
225226
}
226227
std::optional<MMATypes> inferredType =
227-
inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
228+
inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
228229
if (inferredType)
229230
ignoreAttrNames.push_back(frag.ptxTypeAttr);
230231
}
@@ -364,14 +365,14 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
364365
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
365366
parser.getNameLoc(), result.operands)))
366367
return failure();
367-
frag.elemtype =
368-
inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
368+
frag.elemtype = inferOperandMMAType(frag.regTypes[0],
369+
/*isAccumulator*/ iter.index() < 2);
369370
}
370371

371372
Type resultType;
372373
if (parser.parseArrow() || parser.parseType(resultType))
373374
return failure();
374-
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
375+
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
375376

376377
std::array<StringRef, 2> names{"multiplicandAPtxType",
377378
"multiplicandBPtxType"};
@@ -1121,9 +1122,9 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
11211122

11221123
LogicalResult NVVM::MatchSyncOp::verify() {
11231124
if (getKind() == NVVM::MatchSyncKind::all) {
1124-
auto Type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1125-
if (!Type || Type.getBody().size() != 2 ||
1126-
!Type.getBody()[0].isInteger(32) || !Type.getBody()[1].isInteger(1)) {
1125+
auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1126+
if (!type || type.getBody().size() != 2 ||
1127+
!type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
11271128
return emitOpError("match.sync 'all' returns a two element struct with "
11281129
"first element as i32 and second element as i1");
11291130
}
@@ -1164,7 +1165,7 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
11641165
llvm::Intrinsic::ID id;
11651166

11661167
auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1167-
bool hasCpSize = cpAsyncOp.getCpSize() ? true : false;
1168+
bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
11681169
switch (cpAsyncOp.getSize()) {
11691170
case 4:
11701171
id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
@@ -1263,6 +1264,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
12631264
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
12641265
}
12651266

1267+
#define _none
1268+
12661269
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
12671270
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
12681271
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
@@ -1282,7 +1285,7 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12821285
case RndMode::RZ:
12831286
return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
12841287
case RndMode::RNA:
1285-
return GET_CVT_F2TF32_ID(rna, , _satfinite);
1288+
return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
12861289
default:
12871290
llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
12881291
}
@@ -1293,9 +1296,9 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
12931296
LLVM::ModuleTranslation &mt,
12941297
llvm::SmallVector<llvm::Value *> &args) {
12951298
auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1296-
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1299+
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
12971300
.getAddressSpace();
1298-
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
1301+
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
12991302
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
13001303

13011304
llvm::Intrinsic::ID id;
@@ -1342,14 +1345,15 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13421345
LLVM::ModuleTranslation &mt,
13431346
llvm::SmallVector<llvm::Value *> &args) {
13441347
auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1345-
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1348+
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
13461349
.getAddressSpace();
1347-
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
1348-
bool hasMulticast = curOp.getMulticastMask() ? true : false;
1350+
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1351+
bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
13491352
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
13501353

1351-
auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1352-
: GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
1354+
llvm::Intrinsic::ID id =
1355+
is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1356+
: GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
13531357

13541358
// Fill the Intrinsic Args
13551359
args.push_back(mt.lookupValue(curOp.getAddr()));
@@ -1368,9 +1372,9 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13681372

13691373
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
13701374
[&]() -> auto { \
1371-
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
1375+
if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
13721376
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
1373-
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
1377+
if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
13741378
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
13751379
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
13761380
}()
@@ -1400,47 +1404,47 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
14001404

14011405
// Returns the valid vector length for a given shape and vector length, the
14021406
// function models the table mentioned in the tcgen05.{ld, st} Op description
1403-
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape Shape,
1404-
unsigned VecLen) {
1405-
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1406-
return VecLen >= 2;
1407-
if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1408-
return VecLen >= 4;
1407+
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1408+
unsigned vecLen) {
1409+
if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1410+
return vecLen >= 2;
1411+
if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1412+
return vecLen >= 4;
14091413
return true;
14101414
}
14111415

14121416
LogicalResult Tcgen05LdOp::verify() {
1413-
LogicalResult Result = success();
1417+
LogicalResult result = success();
14141418
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1415-
Result = emitError("shape 16x32bx2 requires offset argument");
1419+
result = emitError("shape 16x32bx2 requires offset argument");
14161420

1417-
auto ResTy = getRes().getType();
1418-
unsigned ResLen = isa<VectorType>(ResTy)
1419-
? llvm::cast<VectorType>(ResTy).getNumElements()
1421+
auto resTy = getRes().getType();
1422+
unsigned resLen = isa<VectorType>(resTy)
1423+
? llvm::cast<VectorType>(resTy).getNumElements()
14201424
: 1;
1421-
if (!isValidVectorLength(getShape(), ResLen))
1422-
Result = emitError(llvm::formatv("invalid result type length {0} for shape "
1425+
if (!isValidVectorLength(getShape(), resLen))
1426+
result = emitError(llvm::formatv("invalid result type length {0} for shape "
14231427
"{1} in tcgen05.ld Op",
1424-
ResLen, stringifyEnum(getShape())));
1428+
resLen, stringifyEnum(getShape())));
14251429

1426-
return Result;
1430+
return result;
14271431
}
14281432

14291433
LogicalResult Tcgen05StOp::verify() {
1430-
LogicalResult Result = success();
1434+
LogicalResult result = success();
14311435
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1432-
Result = emitError("shape 16x32bx2 requires offset argument");
1436+
result = emitError("shape 16x32bx2 requires offset argument");
14331437

1434-
auto ValTy = getVal().getType();
1435-
unsigned ValLen = isa<VectorType>(ValTy)
1436-
? llvm::cast<VectorType>(ValTy).getNumElements()
1438+
auto valTy = getVal().getType();
1439+
unsigned valLen = isa<VectorType>(valTy)
1440+
? llvm::cast<VectorType>(valTy).getNumElements()
14371441
: 1;
1438-
if (!isValidVectorLength(getShape(), ValLen))
1439-
Result = emitError(llvm::formatv("invalid input length {0} for shape "
1442+
if (!isValidVectorLength(getShape(), valLen))
1443+
result = emitError(llvm::formatv("invalid input length {0} for shape "
14401444
"{1} in tcgen05.st Op",
1441-
ValLen, stringifyEnum(getShape())));
1445+
valLen, stringifyEnum(getShape())));
14421446

1443-
return Result;
1447+
return result;
14441448
}
14451449

14461450
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
@@ -1560,7 +1564,7 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
15601564
return failure();
15611565
}
15621566
if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1563-
return attr && mlir::isa<StringAttr>(attr);
1567+
return mlir::isa_and_nonnull<StringAttr>(attr);
15641568
})) {
15651569
emitError() << "All the elements in the `link` array must be strings.";
15661570
return failure();

0 commit comments

Comments
 (0)