|
| 1 | +Upstream PR: https://github.com/openai/triton/pull/3213 |
| 2 | + |
| 3 | +diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp |
| 4 | +--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp |
| 5 | ++++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp |
| 6 | +@@ -157,9 +157,10 @@ getSharedMemoryObjectFromStruct(Location |
| 7 | + /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; |
| 8 | + } |
| 9 | + |
| 10 | +-SmallVector<Value> |
| 11 | +-getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order, |
| 12 | +- Location loc, ConversionPatternRewriter &rewriter) { |
| 13 | ++SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, |
| 14 | ++ ArrayRef<unsigned> order, |
| 15 | ++ Location loc, |
| 16 | ++ RewriterBase &rewriter) { |
| 17 | + auto rank = shape.size(); |
| 18 | + SmallVector<Value> strides(rank); |
| 19 | + int64_t stride = 1; |
| 20 | +@@ -172,9 +173,8 @@ getStridesFromShapeAndOrder(ArrayRef<int |
| 21 | + |
| 22 | + // Convert an \param index to a multi-dim coordinate given \param shape and |
| 23 | + // \param order. |
| 24 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 25 | +- Location loc, Value linear, |
| 26 | +- ArrayRef<unsigned> shape, |
| 27 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 28 | ++ Value linear, ArrayRef<unsigned> shape, |
| 29 | + ArrayRef<unsigned> order) { |
| 30 | + unsigned rank = shape.size(); |
| 31 | + assert(rank == order.size()); |
| 32 | +@@ -194,9 +194,8 @@ SmallVector<Value> delinearize(Conversio |
| 33 | + return multiDim; |
| 34 | + } |
| 35 | + |
| 36 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 37 | +- Location loc, unsigned linear, |
| 38 | +- ArrayRef<unsigned> shape) { |
| 39 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 40 | ++ unsigned linear, ArrayRef<unsigned> shape) { |
| 41 | + unsigned rank = shape.size(); |
| 42 | + assert(rank > 0); |
| 43 | + SmallVector<Value> multiDim(rank); |
| 44 | +@@ -209,9 +208,8 @@ SmallVector<Value> delinearize(Conversio |
| 45 | + return multiDim; |
| 46 | + } |
| 47 | + |
| 48 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 49 | +- Location loc, Value linear, |
| 50 | +- ArrayRef<unsigned> shape) { |
| 51 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 52 | ++ Value linear, ArrayRef<unsigned> shape) { |
| 53 | + unsigned rank = shape.size(); |
| 54 | + assert(rank > 0); |
| 55 | + SmallVector<Value> multiDim(rank); |
| 56 | +diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h |
| 57 | +--- a/lib/Conversion/TritonGPUToLLVM/Utility.h |
| 58 | ++++ b/lib/Conversion/TritonGPUToLLVM/Utility.h |
| 59 | +@@ -232,9 +232,10 @@ void createStoreDSmem(Location loc, Patt |
| 60 | + Value ctaId, ArrayRef<Value> values); |
| 61 | + |
| 62 | + /// Helper function to get strides from a given shape and its order |
| 63 | +-SmallVector<Value> |
| 64 | +-getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order, |
| 65 | +- Location loc, ConversionPatternRewriter &rewriter); |
| 66 | ++SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, |
| 67 | ++ ArrayRef<unsigned> order, |
| 68 | ++ Location loc, |
| 69 | ++ RewriterBase &rewriter); |
| 70 | + struct SharedMemoryObject { |
| 71 | + Value base; // i32 ptr. The start address of the shared memory object after |
| 72 | + // the initial allocation or the last slicing operation. |
| 73 | +@@ -264,7 +265,7 @@ struct SharedMemoryObject { |
| 74 | + |
| 75 | + SharedMemoryObject(Value base, Type baseElemType, ArrayRef<int64_t> shape, |
| 76 | + ArrayRef<unsigned> order, Location loc, |
| 77 | +- ConversionPatternRewriter &rewriter) |
| 78 | ++ RewriterBase &rewriter) |
| 79 | + : base(base), baseElemType(baseElemType) { |
| 80 | + strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); |
| 81 | + offsets.append(order.size(), i32_val(0)); |
| 82 | +@@ -311,18 +312,15 @@ getSharedMemoryObjectFromStruct(Location |
| 83 | + |
| 84 | + // Convert an \param index to a multi-dim coordinate given \param shape and |
| 85 | + // \param order. |
| 86 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 87 | +- Location loc, Value linear, |
| 88 | +- ArrayRef<unsigned> shape, |
| 89 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 90 | ++ Value linear, ArrayRef<unsigned> shape, |
| 91 | + ArrayRef<unsigned> order); |
| 92 | + |
| 93 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 94 | +- Location loc, unsigned linear, |
| 95 | +- ArrayRef<unsigned> shape); |
| 96 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 97 | ++ unsigned linear, ArrayRef<unsigned> shape); |
| 98 | + |
| 99 | +-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, |
| 100 | +- Location loc, Value linear, |
| 101 | +- ArrayRef<unsigned> shape); |
| 102 | ++SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc, |
| 103 | ++ Value linear, ArrayRef<unsigned> shape); |
| 104 | + |
| 105 | + Value linearize(ConversionPatternRewriter &rewriter, Location loc, |
| 106 | + ArrayRef<Value> multiDim, ArrayRef<unsigned> shape, |
| 107 | +@@ -380,22 +378,20 @@ static Value getSharedMemoryBase(Locatio |
| 108 | + |
| 109 | + /* ------------------------------------ */ |
| 110 | + // Returns CTA level thread idx |
| 111 | +-static Value getThreadIdInCTA(ConversionPatternRewriter &rewriter, |
| 112 | +- Location loc) { |
| 113 | ++static Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { |
| 114 | + Value tid = |
| 115 | + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); |
| 116 | + return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid); |
| 117 | + } |
| 118 | + |
| 119 | + // Returns CTA level thread idx. |
| 120 | +-static Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) { |
| 121 | ++static Value getThreadId(RewriterBase &rewriter, Location loc) { |
| 122 | + Value tid = getThreadIdInCTA(rewriter, loc); |
| 123 | + auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>(); |
| 124 | + return tid; |
| 125 | + } |
| 126 | + |
| 127 | +-static Value getClusterCTAId(ConversionPatternRewriter &rewriter, |
| 128 | +- Location loc) { |
| 129 | ++static Value getClusterCTAId(RewriterBase &rewriter, Location loc) { |
| 130 | + return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc, |
| 131 | + rewriter.getI32Type()); |
| 132 | + } |
| 133 | +@@ -413,8 +409,8 @@ using ::mlir::triton::gpu::DotOperandEnc |
| 134 | + using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; |
| 135 | + using ::mlir::triton::gpu::SliceEncodingAttr; |
| 136 | + |
| 137 | +-static Value dot(ConversionPatternRewriter &rewriter, Location loc, |
| 138 | +- ArrayRef<Value> offsets, ArrayRef<Value> strides) { |
| 139 | ++static Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets, |
| 140 | ++ ArrayRef<Value> strides) { |
| 141 | + assert(offsets.size() == strides.size()); |
| 142 | + Value ret = i32_val(0); |
| 143 | + for (auto [offset, stride] : llvm::zip(offsets, strides)) { |
| 144 | +@@ -428,9 +424,10 @@ static Value dot(ConversionPatternRewrit |
| 145 | + // ----------------------------------------------------------------------- |
| 146 | + |
| 147 | + // Get an index-base for each dimension for a \param blockedLayout. |
| 148 | +-static SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout( |
| 149 | +- Location loc, ConversionPatternRewriter &rewriter, |
| 150 | +- const BlockedEncodingAttr &blockedLayout, RankedTensorType type) { |
| 151 | ++static SmallVector<Value> |
| 152 | ++emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, |
| 153 | ++ const BlockedEncodingAttr &blockedLayout, |
| 154 | ++ RankedTensorType type) { |
| 155 | + auto shape = type.getShape(); |
| 156 | + Value threadId = getThreadId(rewriter, loc); |
| 157 | + Value warpSize = i32_val(32); |
| 158 | +@@ -511,9 +508,10 @@ emitOffsetForBlockedLayout(const Blocked |
| 159 | + // Mma layout indices |
| 160 | + // ----------------------------------------------------------------------- |
| 161 | + |
| 162 | +-static SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV1( |
| 163 | +- Location loc, ConversionPatternRewriter &rewriter, |
| 164 | +- const NvidiaMmaEncodingAttr &mmaLayout, RankedTensorType type) { |
| 165 | ++static SmallVector<Value> |
| 166 | ++emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter, |
| 167 | ++ const NvidiaMmaEncodingAttr &mmaLayout, |
| 168 | ++ RankedTensorType type) { |
| 169 | + auto shape = type.getShape(); |
| 170 | + auto wpt = mmaLayout.getWarpsPerCTA(); |
| 171 | + static constexpr std::array<int, 3> fpw{{2, 2, 1}}; |
| 172 | +@@ -654,9 +652,10 @@ emitOffsetForMmaLayoutV2(const NvidiaMma |
| 173 | + return ret; |
| 174 | + } |
| 175 | + |
| 176 | +-static SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV2V3( |
| 177 | +- Location loc, ConversionPatternRewriter &rewriter, |
| 178 | +- const NvidiaMmaEncodingAttr &mmaLayout, RankedTensorType type) { |
| 179 | ++static SmallVector<Value> |
| 180 | ++emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, |
| 181 | ++ const NvidiaMmaEncodingAttr &mmaLayout, |
| 182 | ++ RankedTensorType type) { |
| 183 | + auto shape = type.getShape(); |
| 184 | + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); |
| 185 | + auto rank = shape.size(); |
| 186 | +@@ -776,9 +775,10 @@ emitOffsetForSliceLayout(const SliceEnco |
| 187 | + // Get offsets / indices for any layout |
| 188 | + // ----------------------------------------------------------------------- |
| 189 | + |
| 190 | +-static SmallVector<Value> |
| 191 | +-emitCTAOffsetForLayout(Location loc, ConversionPatternRewriter &rewriter, |
| 192 | +- Attribute layout, ArrayRef<int64_t> shape) { |
| 193 | ++static SmallVector<Value> emitCTAOffsetForLayout(Location loc, |
| 194 | ++ RewriterBase &rewriter, |
| 195 | ++ Attribute layout, |
| 196 | ++ ArrayRef<int64_t> shape) { |
| 197 | + unsigned rank = shape.size(); |
| 198 | + SmallVector<unsigned> CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); |
| 199 | + SmallVector<unsigned> CTASplitNum = triton::gpu::getCTASplitNum(layout); |
| 200 | +@@ -806,13 +806,12 @@ emitCTAOffsetForLayout(Location loc, Con |
| 201 | + } |
| 202 | + |
| 203 | + static SmallVector<Value> |
| 204 | +-emitBaseIndexForLayout(Location loc, ConversionPatternRewriter &rewriter, |
| 205 | +- Attribute layout, RankedTensorType type, |
| 206 | +- bool withCTAOffset) { |
| 207 | ++emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout, |
| 208 | ++ RankedTensorType type, bool withCTAOffset) { |
| 209 | + auto shape = type.getShape(); |
| 210 | + |
| 211 | + SmallVector<Value> baseIndex; |
| 212 | +- ConversionPatternRewriter::InsertionGuard guard(rewriter); |
| 213 | ++ RewriterBase::InsertionGuard guard(rewriter); |
| 214 | + SmallVector<Value> result; |
| 215 | + if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) { |
| 216 | + result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, |
| 217 | +@@ -866,7 +865,7 @@ emitOffsetForLayout(Attribute layout, Ra |
| 218 | + // Emit indices calculation within each ConversionPattern, and returns a |
| 219 | + // [elemsPerThread X rank] index matrix. |
| 220 | + static SmallVector<SmallVector<Value>> |
| 221 | +-emitIndices(Location loc, ConversionPatternRewriter &rewriter, Attribute layout, |
| 222 | ++emitIndices(Location loc, RewriterBase &rewriter, Attribute layout, |
| 223 | + RankedTensorType type, bool withCTAOffset) { |
| 224 | + // step 1, delinearize threadId to get the base index |
| 225 | + auto multiDimBase = |
| 226 | +@@ -892,7 +891,7 @@ emitIndices(Location loc, ConversionPatt |
| 227 | + DenseMap<unsigned, Value> static getSwizzledSharedPtrs( |
| 228 | + Location loc, unsigned inVec, RankedTensorType srcTy, |
| 229 | + triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy, |
| 230 | +- SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter, |
| 231 | ++ SharedMemoryObject smemObj, RewriterBase &rewriter, |
| 232 | + SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) { |
| 233 | + // This utility computes the pointers for accessing the provided swizzled |
| 234 | + // shared memory layout `resSharedLayout`. More specifically, it computes, |
| 235 | +diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp |
| 236 | +--- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp |
| 237 | ++++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp |
| 238 | +@@ -78,7 +78,7 @@ private: |
| 239 | + LowerToLLVMOptions option; |
| 240 | + TritonGPUToLLVMTypeConverter typeConverter; |
| 241 | + Block block; |
| 242 | +- ConversionPatternRewriter rewriter; |
| 243 | ++ IRRewriter rewriter; |
| 244 | + Location loc; |
| 245 | + }; |
| 246 | + |
0 commit comments