33
33
#include " llvm/AsmParser/Parser.h"
34
34
#include " llvm/IR/Attributes.h"
35
35
#include " llvm/IR/Function.h"
36
+ #include " llvm/IR/IntrinsicsNVPTX.h"
36
37
#include " llvm/IR/Type.h"
37
38
#include " llvm/Support/Casting.h"
38
39
#include " llvm/Support/FormatVariadic.h"
@@ -56,7 +57,7 @@ using namespace NVVM;
56
57
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
57
58
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
58
59
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
59
- static LogicalResult CpAsyncBulkTensorCommonVerifier (size_t tensorDims,
60
+ static LogicalResult cpAsyncBulkTensorCommonVerifier (size_t tensorDims,
60
61
bool isIm2Col,
61
62
size_t numIm2ColOffsets,
62
63
Location loc) {
@@ -81,7 +82,7 @@ static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
81
82
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify () {
82
83
size_t numIm2ColOffsets = getIm2colOffsets ().size ();
83
84
bool isIm2Col = numIm2ColOffsets > 0 ;
84
- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
85
+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
85
86
numIm2ColOffsets, getLoc ());
86
87
}
87
88
@@ -105,13 +106,13 @@ LogicalResult CpAsyncOp::verify() {
105
106
LogicalResult CpAsyncBulkTensorPrefetchOp::verify () {
106
107
size_t numIm2ColOffsets = getIm2colOffsets ().size ();
107
108
bool isIm2Col = numIm2ColOffsets > 0 ;
108
- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
109
+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
109
110
numIm2ColOffsets, getLoc ());
110
111
}
111
112
112
113
LogicalResult CpAsyncBulkTensorReduceOp::verify () {
113
114
bool isIm2Col = (getMode () == TMAStoreMode::IM2COL);
114
- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col, 0 ,
115
+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col, 0 ,
115
116
getLoc ());
116
117
}
117
118
@@ -183,14 +184,14 @@ static bool isIntegerPtxType(MMATypes type) {
183
184
184
185
MMATypes MmaOp::accumPtxType () {
185
186
std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType (
186
- getODSOperands (2 ).getTypes ().front (), /* isAccum =*/ true );
187
+ getODSOperands (2 ).getTypes ().front (), /* isAccumulator =*/ true );
187
188
assert (val.has_value () && " accumulator PTX type should always be inferrable" );
188
189
return val.value ();
189
190
}
190
191
191
192
MMATypes MmaOp::resultPtxType () {
192
193
std::optional<mlir::NVVM::MMATypes> val =
193
- inferOperandMMAType (getResult ().getType (), /* isAccum =*/ true );
194
+ inferOperandMMAType (getResult ().getType (), /* isAccumulator =*/ true );
194
195
assert (val.has_value () && " result PTX type should always be inferrable" );
195
196
return val.value ();
196
197
}
@@ -224,7 +225,7 @@ void MmaOp::print(OpAsmPrinter &p) {
224
225
}
225
226
}
226
227
std::optional<MMATypes> inferredType =
227
- inferOperandMMAType (regTypes.back (), /* isAccum =*/ fragIdx >= 2 );
228
+ inferOperandMMAType (regTypes.back (), /* isAccumulator =*/ fragIdx >= 2 );
228
229
if (inferredType)
229
230
ignoreAttrNames.push_back (frag.ptxTypeAttr );
230
231
}
@@ -364,14 +365,14 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
364
365
if (failed (parser.resolveOperands (frag.regs , frag.regTypes ,
365
366
parser.getNameLoc (), result.operands )))
366
367
return failure ();
367
- frag.elemtype =
368
- inferOperandMMAType (frag. regTypes [ 0 ], /* isAccum= */ iter.index () < 2 );
368
+ frag.elemtype = inferOperandMMAType (frag. regTypes [ 0 ],
369
+ /* isAccumulator */ iter.index () < 2 );
369
370
}
370
371
371
372
Type resultType;
372
373
if (parser.parseArrow () || parser.parseType (resultType))
373
374
return failure ();
374
- frags[3 ].elemtype = inferOperandMMAType (resultType, /* isAccum= */ true );
375
+ frags[3 ].elemtype = inferOperandMMAType (resultType, /* isAccumulator */ true );
375
376
376
377
std::array<StringRef, 2 > names{" multiplicandAPtxType" ,
377
378
" multiplicandBPtxType" };
@@ -1121,9 +1122,9 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
1121
1122
1122
1123
LogicalResult NVVM::MatchSyncOp::verify () {
1123
1124
if (getKind () == NVVM::MatchSyncKind::all) {
1124
- auto Type = llvm::dyn_cast<LLVM::LLVMStructType>(getType ());
1125
- if (!Type || Type .getBody ().size () != 2 ||
1126
- !Type .getBody ()[0 ].isInteger (32 ) || !Type .getBody ()[1 ].isInteger (1 )) {
1125
+ auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType ());
1126
+ if (!type || type .getBody ().size () != 2 ||
1127
+ !type .getBody ()[0 ].isInteger (32 ) || !type .getBody ()[1 ].isInteger (1 )) {
1127
1128
return emitOpError (" match.sync 'all' returns a two element struct with "
1128
1129
" first element as i32 and second element as i1" );
1129
1130
}
@@ -1164,7 +1165,7 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1164
1165
llvm::Intrinsic::ID id;
1165
1166
1166
1167
auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1167
- bool hasCpSize = cpAsyncOp.getCpSize () ? true : false ;
1168
+ bool hasCpSize = static_cast < bool >( cpAsyncOp.getCpSize ()) ;
1168
1169
switch (cpAsyncOp.getSize ()) {
1169
1170
case 4 :
1170
1171
id = GET_CP_ASYNC_ID (ca, 4 , hasCpSize);
@@ -1263,6 +1264,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1263
1264
llvm_unreachable (" Invalid Reduction Op for CpAsyncBulkTensorReduceOp" );
1264
1265
}
1265
1266
1267
+ #define _none
1268
+
1266
1269
#define CVT_F2TF32_ID_IMPL (rnd, relu, sf ) \
1267
1270
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1268
1271
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
@@ -1282,7 +1285,7 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1282
1285
case RndMode::RZ:
1283
1286
return GET_CVT_F2TF32_ID (rz, _relu, _satfinite);
1284
1287
case RndMode::RNA:
1285
- return GET_CVT_F2TF32_ID (rna, , _satfinite);
1288
+ return GET_CVT_F2TF32_ID (rna, _none , _satfinite);
1286
1289
default :
1287
1290
llvm_unreachable (" Invalid RoundingMode for CvtFloatToTF32Op" );
1288
1291
}
@@ -1293,9 +1296,9 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1293
1296
LLVM::ModuleTranslation &mt,
1294
1297
llvm::SmallVector<llvm::Value *> &args) {
1295
1298
auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1296
- unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1299
+ unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1297
1300
.getAddressSpace ();
1298
- bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace ;
1301
+ bool isShared = as == NVVMMemorySpace::kSharedMemorySpace ;
1299
1302
bool is2CTAMode = curOp.getGroup () == Tcgen05GroupKind::CTA_2;
1300
1303
1301
1304
llvm::Intrinsic::ID id;
@@ -1342,14 +1345,15 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1342
1345
LLVM::ModuleTranslation &mt,
1343
1346
llvm::SmallVector<llvm::Value *> &args) {
1344
1347
auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1345
- unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1348
+ unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1346
1349
.getAddressSpace ();
1347
- bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace ;
1348
- bool hasMulticast = curOp.getMulticastMask () ? true : false ;
1350
+ bool isShared = as == NVVMMemorySpace::kSharedMemorySpace ;
1351
+ bool hasMulticast = static_cast < bool >( curOp.getMulticastMask ()) ;
1349
1352
bool is2CTAMode = curOp.getGroup () == Tcgen05GroupKind::CTA_2;
1350
1353
1351
- auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID (cg2, isShared, hasMulticast)
1352
- : GET_TCGEN05_COMMIT_ID (cg1, isShared, hasMulticast);
1354
+ llvm::Intrinsic::ID id =
1355
+ is2CTAMode ? GET_TCGEN05_COMMIT_ID (cg2, isShared, hasMulticast)
1356
+ : GET_TCGEN05_COMMIT_ID (cg1, isShared, hasMulticast);
1353
1357
1354
1358
// Fill the Intrinsic Args
1355
1359
args.push_back (mt.lookupValue (curOp.getAddr ()));
@@ -1368,9 +1372,9 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1368
1372
1369
1373
#define GET_TCGEN05_CP_ID (shape_mc, src_fmt, is_2cta ) \
1370
1374
[&]() -> auto { \
1371
- if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
1375
+ if (( src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
1372
1376
return TCGEN05_CP_2CTA (shape_mc, _b6x16_p32, is_2cta); \
1373
- if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
1377
+ if (( src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
1374
1378
return TCGEN05_CP_2CTA (shape_mc, _b4x16_p64, is_2cta); \
1375
1379
return TCGEN05_CP_2CTA (shape_mc, , is_2cta); \
1376
1380
}()
@@ -1400,47 +1404,47 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1400
1404
1401
1405
// Returns the valid vector length for a given shape and vector length, the
1402
1406
// function models the table mentioned in the tcgen05.{ld, st} Op description
1403
- static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape Shape ,
1404
- unsigned VecLen ) {
1405
- if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1406
- return VecLen >= 2 ;
1407
- if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1408
- return VecLen >= 4 ;
1407
+ static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape shape ,
1408
+ unsigned vecLen ) {
1409
+ if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1410
+ return vecLen >= 2 ;
1411
+ if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1412
+ return vecLen >= 4 ;
1409
1413
return true ;
1410
1414
}
1411
1415
1412
1416
LogicalResult Tcgen05LdOp::verify () {
1413
- LogicalResult Result = success ();
1417
+ LogicalResult result = success ();
1414
1418
if (getShape () == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset ())
1415
- Result = emitError (" shape 16x32bx2 requires offset argument" );
1419
+ result = emitError (" shape 16x32bx2 requires offset argument" );
1416
1420
1417
- auto ResTy = getRes ().getType ();
1418
- unsigned ResLen = isa<VectorType>(ResTy )
1419
- ? llvm::cast<VectorType>(ResTy ).getNumElements ()
1421
+ auto resTy = getRes ().getType ();
1422
+ unsigned resLen = isa<VectorType>(resTy )
1423
+ ? llvm::cast<VectorType>(resTy ).getNumElements ()
1420
1424
: 1 ;
1421
- if (!isValidVectorLength (getShape (), ResLen ))
1422
- Result = emitError (llvm::formatv (" invalid result type length {0} for shape "
1425
+ if (!isValidVectorLength (getShape (), resLen ))
1426
+ result = emitError (llvm::formatv (" invalid result type length {0} for shape "
1423
1427
" {1} in tcgen05.ld Op" ,
1424
- ResLen , stringifyEnum (getShape ())));
1428
+ resLen , stringifyEnum (getShape ())));
1425
1429
1426
- return Result ;
1430
+ return result ;
1427
1431
}
1428
1432
1429
1433
LogicalResult Tcgen05StOp::verify () {
1430
- LogicalResult Result = success ();
1434
+ LogicalResult result = success ();
1431
1435
if (getShape () == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset ())
1432
- Result = emitError (" shape 16x32bx2 requires offset argument" );
1436
+ result = emitError (" shape 16x32bx2 requires offset argument" );
1433
1437
1434
- auto ValTy = getVal ().getType ();
1435
- unsigned ValLen = isa<VectorType>(ValTy )
1436
- ? llvm::cast<VectorType>(ValTy ).getNumElements ()
1438
+ auto valTy = getVal ().getType ();
1439
+ unsigned valLen = isa<VectorType>(valTy )
1440
+ ? llvm::cast<VectorType>(valTy ).getNumElements ()
1437
1441
: 1 ;
1438
- if (!isValidVectorLength (getShape (), ValLen ))
1439
- Result = emitError (llvm::formatv (" invalid input length {0} for shape "
1442
+ if (!isValidVectorLength (getShape (), valLen ))
1443
+ result = emitError (llvm::formatv (" invalid input length {0} for shape "
1440
1444
" {1} in tcgen05.st Op" ,
1441
- ValLen , stringifyEnum (getShape ())));
1445
+ valLen , stringifyEnum (getShape ())));
1442
1446
1443
- return Result ;
1447
+ return result ;
1444
1448
}
1445
1449
1446
1450
// / Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
@@ -1560,7 +1564,7 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1560
1564
return failure ();
1561
1565
}
1562
1566
if (files && !llvm::all_of (files, [](::mlir::Attribute attr) {
1563
- return attr && mlir::isa <StringAttr>(attr);
1567
+ return mlir::isa_and_nonnull <StringAttr>(attr);
1564
1568
})) {
1565
1569
emitError () << " All the elements in the `link` array must be strings." ;
1566
1570
return failure ();
0 commit comments