20
20
#include " mlir/IR/PatternMatch.h"
21
21
#include " mlir/IR/TypeUtilities.h"
22
22
#include " mlir/Pass/Pass.h"
23
+ #include " llvm/Support/Debug.h"
23
24
#include " llvm/Support/raw_ostream.h"
24
25
26
+ #define DEBUG_TYPE " nvgpu-to-nvvm"
27
+ #define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
28
+ #define DBGSE () (llvm::dbgs())
29
+
25
30
namespace mlir {
26
31
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
27
32
#include " mlir/Conversion/Passes.h.inc"
@@ -980,9 +985,14 @@ struct NVGPUGenerateGmmaDescriptorLowering
980
985
};
981
986
982
987
int ex4LSB = 4 ;
983
- Value strideDim = makeConst ((layout << 3 ) >> ex4LSB);
984
988
int64_t sizeN = op.getTensorMap ().getType ().getTensor ().getDimSize (0 );
985
- Value leadDim = makeConst ((sizeN * layout) >> ex4LSB);
989
+ uint64_t strideDimVal = (layout << 3 ) >> ex4LSB;
990
+ uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
991
+ uint64_t offsetVal = 0 ;
992
+
993
+ Value strideDim = makeConst (strideDimVal);
994
+ Value leadDim = makeConst (leadDimVal);
995
+
986
996
Value baseAddr = getStridedElementPtr (
987
997
op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
988
998
adaptor.getTensor (), {}, rewriter);
@@ -996,14 +1006,22 @@ struct NVGPUGenerateGmmaDescriptorLowering
996
1006
// // [62,64) swizzle type
997
1007
dsc = insertBit (dsc, makeConst (swizzle), startSwizzleBit);
998
1008
// // [49,52) base_offset
999
- dsc = insertBit (dsc, makeConst (0 ), startOffsetBit);
1009
+ dsc = insertBit (dsc, makeConst (offsetVal ), startOffsetBit);
1000
1010
// // [32,46) stride
1001
1011
dsc = insertBit (dsc, strideDim, startStrideBit);
1002
1012
// // [16,30) leading dimension
1003
1013
dsc = insertBit (dsc, leadDim, startLeadBit);
1004
1014
// // [0,14) start_address
1005
1015
dsc = insertBit (dsc, basePtr14bit, startBaseAddrBit);
1006
1016
1017
+ LLVM_DEBUG (DBGS () << " Generating wgmma.descriptor: "
1018
+ << " leading_off:" << leadDimVal << " \t "
1019
+ << " stride_off :" << strideDimVal << " \t "
1020
+ << " base_offset:" << offsetVal << " \t "
1021
+ << " layout_type:" << swizzle << " ("
1022
+ << nvgpu::stringifyTensorMapSwizzleKind (swizzleKind)
1023
+ << " )\n start_addr : " << baseAddr << " \n " );
1024
+
1007
1025
rewriter.replaceOp (op, dsc);
1008
1026
return success ();
1009
1027
}
0 commit comments