Skip to content

Commit b96d069

Browse files
committed
[NVGPU] Add debug in nvgpu (nfc)
Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159343
1 parent de4d742 commit b96d069

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/TypeUtilities.h"
2222
#include "mlir/Pass/Pass.h"
23+
#include "llvm/Support/Debug.h"
2324
#include "llvm/Support/raw_ostream.h"
2425

26+
// Tag identifying this file's debug output; DBGS() prints it in brackets as a
// line prefix. NOTE(review): presumably also the name matched by
// `-debug-only=nvgpu-to-nvvm` through LLVM_DEBUG -- confirm against
// llvm/Support/Debug.h.
#define DEBUG_TYPE "nvgpu-to-nvvm"
27+
// Debug stream with a "[<DEBUG_TYPE>] " prefix already written, so a message
// can be appended directly: DBGS() << "text".
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
28+
// Debug stream without any prefix (plain llvm::dbgs()). NOTE(review):
// presumably meant for continuing output started with DBGS() -- confirm at
// the call sites.
#define DBGSE() (llvm::dbgs())
29+
2530
namespace mlir {
2631
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
2732
#include "mlir/Conversion/Passes.h.inc"
@@ -980,9 +985,14 @@ struct NVGPUGenerateGmmaDescriptorLowering
980985
};
981986

982987
int ex4LSB = 4;
983-
Value strideDim = makeConst((layout << 3) >> ex4LSB);
984988
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
985-
Value leadDim = makeConst((sizeN * layout) >> ex4LSB);
989+
uint64_t strideDimVal = (layout << 3) >> ex4LSB;
990+
uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
991+
uint64_t offsetVal = 0;
992+
993+
Value strideDim = makeConst(strideDimVal);
994+
Value leadDim = makeConst(leadDimVal);
995+
986996
Value baseAddr = getStridedElementPtr(
987997
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
988998
adaptor.getTensor(), {}, rewriter);
@@ -996,14 +1006,22 @@ struct NVGPUGenerateGmmaDescriptorLowering
9961006
// // [62,64) swizzle type
9971007
dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
9981008
// // [49,52) base_offset
999-
dsc = insertBit(dsc, makeConst(0), startOffsetBit);
1009+
dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
10001010
// // [32,46) stride
10011011
dsc = insertBit(dsc, strideDim, startStrideBit);
10021012
// // [16,30) leading dimension
10031013
dsc = insertBit(dsc, leadDim, startLeadBit);
10041014
// // [0,14) start_address
10051015
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
10061016

1017+
LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: "
1018+
<< "leading_off:" << leadDimVal << "\t"
1019+
<< "stride_off :" << strideDimVal << "\t"
1020+
<< "base_offset:" << offsetVal << "\t"
1021+
<< "layout_type:" << swizzle << " ("
1022+
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1023+
<< ")\n start_addr : " << baseAddr << "\n");
1024+
10071025
rewriter.replaceOp(op, dsc);
10081026
return success();
10091027
}

0 commit comments

Comments
 (0)