Commit eefe169

ThomasRaoux authored and vlad-penkin committed
[BACKEND] Relax layout supported by SplitOp (#4653)
Remove the restriction that the split dimension needs to be the fastest-moving one. As long as all the registers involved stay within a thread, we can implement SplitOp as a no-op. This allows more layout propagation.
1 parent 34aaee0 commit eefe169
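
To make the relaxation concrete before the diffs: when every register along the split dimension stays within one thread, lowering tt.split is just a re-grouping of that thread's registers. Below is a minimal standalone sketch (illustrative, not code from this commit) of the de-interleaving; the value of numContiguousValues is assumed here, while the real lowering derives it from the layout (see ViewOpToLLVM.cpp below).

// Sketch: de-interleave one thread's registers into the two split results.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Assume 8 registers per thread, with 2 consecutive registers belonging
  // to the same chunk before the layout alternates between the two halves.
  std::vector<int> srcVals = {0, 1, 2, 3, 4, 5, 6, 7};
  int numContiguousValues = 2; // assumed; derived from the layout in Triton
  assert(srcVals.size() % (2 * numContiguousValues) == 0);

  std::vector<int> outLhs, outRhs;
  for (size_t i = 0; i < srcVals.size(); i += 2 * numContiguousValues) {
    for (int j = 0; j < numContiguousValues; j++) {
      outLhs.push_back(srcVals[i + j]);
      outRhs.push_back(srcVals[i + numContiguousValues + j]);
    }
  }

  // Prints "lhs: 0 1 4 5 | rhs: 2 3 6 7": a pure register shuffle, so no
  // data moves between threads or through shared memory.
  printf("lhs:");
  for (int v : outLhs) printf(" %d", v);
  printf(" | rhs:");
  for (int v : outRhs) printf(" %d", v);
  printf("\n");
}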

File tree

4 files changed: +71 −32 lines

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 19 additions & 7 deletions
@@ -172,21 +172,33 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     // verifier):
     //
     // - The op has a blocked encoding.
-    // - The last dimension (the one we're splitting) is also the most minor
-    //   dimension, and has sizePerThread=2.
+    // - The last dimension (the one we're splitting) has sizePerThread=2,
+    //   threadsPerWarp=1 and warpsPerCTA=1.
     //
-    // With these invariants, split is trivial: Every other value goes into
-    // return value 0, and every other goes into return value 1.
+    // With these invariants, split is trivial: we count how many contiguous
+    // registers belong to the same chunk, then separate the registers of the
+    // two chunks.
+    int numContiguousValues = 1;
+    auto encoding = cast<BlockedEncodingAttr>(
+        cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
+    int splitDim = encoding.getOrder().size() - 1;
+    for (int i = 0; i < encoding.getOrder().size(); i++) {
+      if (encoding.getOrder()[i] == splitDim)
+        break;
+      numContiguousValues *= encoding.getSizePerThread()[i];
+    }
     Location loc = op->getLoc();
     auto typeConverter = getTypeConverter();
     SmallVector<Value> srcVals =
         unpackLLElements(loc, adaptor.getSrc(), rewriter);
     assert(srcVals.size() % 2 == 0);
     SmallVector<Value> outLhsVals;
     SmallVector<Value> outRhsVals;
-    for (int i = 0; i < srcVals.size(); i += 2) {
-      outLhsVals.push_back(srcVals[i]);
-      outRhsVals.push_back(srcVals[i + 1]);
+    for (int i = 0; i < srcVals.size(); i += 2 * numContiguousValues) {
+      for (int j = 0; j < numContiguousValues; j++) {
+        outLhsVals.push_back(srcVals[i + j]);
+        outRhsVals.push_back(srcVals[i + numContiguousValues + j]);
+      }
     }
     auto resultTy = cast<RankedTensorType>(op.getResult(0).getType());
     Value retLhs =
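
A worked example of the new numContiguousValues loop may help (a sketch under assumed values that mirrors the committed loop; it is not part of the commit). It uses the #blocked layout from the combine.mlir test at the bottom of this page, where order[i] == i, so the indexing is unambiguous.

// Sketch: compute how many consecutive registers stay in the same chunk.
#include <cstdio>
#include <vector>

int main() {
  // From the test's #blocked encoding: sizePerThread = [1, 64, 2],
  // order = [0, 1, 2]; the split dimension is always the last one.
  std::vector<int> order = {0, 1, 2};
  std::vector<int> sizePerThread = {1, 64, 2};
  int splitDim = order.size() - 1; // dimension 2

  int numContiguousValues = 1;
  for (size_t i = 0; i < order.size(); i++) {
    if (order[i] == splitDim)
      break;
    numContiguousValues *= sizePerThread[i];
  }

  // Dims 0 and 1 are faster-moving than the split dim, so 1 * 64 = 64
  // consecutive registers belong to the same chunk. Under the old rule
  // (order starts with the split dim) the loop exits immediately and the
  // value stays 1, i.e. strictly alternating registers, as before.
  printf("numContiguousValues = %d\n", numContiguousValues);
}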

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 34 additions & 10 deletions
@@ -12,6 +12,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
@@ -2661,22 +2662,21 @@ struct TritonGPUInferLayoutInterface
           loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
                "and CTAsPerCGA = 1 for the last dimension of the input");
     }
-    if (enc.getOrder().front() != enc.getOrder().size() - 1) {
-      return emitOptionalError(
-          loc, "SplitOp requires the last dimension to be most-minor in order");
-    }
     if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
       return emitOptionalError(
           loc,
           "SplitOp requires the last dimension to be most-minor in CTAOrder");
     }
-
+    SmallVector<unsigned> newOrder(enc.getOrder());
+    int splitDim = newOrder.size() - 1;
+    // Remove splitDim from order.
+    newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
+                   newOrder.end());
     dstEnc = BlockedEncodingAttr::get(
         enc.getContext(), //
         ArrayRef(enc.getSizePerThread()).drop_back(1),
         ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
-        ArrayRef(enc.getWarpsPerCTA()).drop_back(1),
-        ArrayRef(enc.getOrder()).drop_front(1),
+        ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
         CTALayoutAttr::get(enc.getContext(), //
                            ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
                            ArrayRef(enc.getCTASplitNum()).drop_back(1),
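
One detail worth calling out: the old code dropped the front of order because the verifier forced the split dimension to be most-minor (order.front()). With that restriction lifted, the split dimension can sit anywhere in order, so it must be erased by value. A minimal sketch of the difference (assumed values, not commit code):

// Sketch: erase-by-value vs. drop_front() for the inferred result order.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Assume a rank-3 input whose split dim (2) is the slowest-moving one.
  std::vector<unsigned> order = {0, 1, 2};
  unsigned splitDim = order.size() - 1;

  // Dropping the front would yield [1, 2], which still names the removed
  // dimension 2. Erasing splitDim by value yields the correct order [0, 1].
  order.erase(std::remove(order.begin(), order.end(), splitDim), order.end());

  for (unsigned d : order)
    printf("%u ", d); // prints "0 1"
  printf("\n");
}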
@@ -2764,6 +2764,28 @@ struct CanonicalizeConvertFromLocalStore
   }
 };
 
+struct CanonicalizeConvertFromSplit
+    : public mlir::OpRewritePattern<triton::SplitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::SplitOp op,
+                  PatternRewriter &rewriter) const override {
+    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
+    if (!convert)
+      return failure();
+    auto srcEncoding = convert.getSrc().getType().getEncoding();
+    // Multiple source layouts can give the same output layout; if the source
+    // layout of the convert gives the same destination layout, we can skip
+    // the convert.
+    auto dstEncoding = inferDstEncoding(op, srcEncoding);
+    if (dstEncoding != op.getOutLHS().getType().getEncoding())
+      return failure();
+    rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
+    return mlir::success();
+  }
+};
+
 struct CanonicalizeConvertFromConvert
     : public OpRewritePattern<ConvertLayoutOp> {
   using OpRewritePattern::OpRewritePattern;
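
In short, the new pattern folds a convert_layout feeding a tt.split whenever the convert is redundant: several source layouts can infer the same split-result layout, so if inferDstEncoding applied to the convert's source encoding already matches the encoding of the split's results, the split can consume the convert's operand directly and the convert becomes dead.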
@@ -2896,6 +2918,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeConvertFromHistogram>(context);
   patterns.add<CanonicalizeConvertFromAlloc>(context);
   patterns.add<CanonicalizeConvertFromLocalStore>(context);
+  patterns.add<CanonicalizeConvertFromSplit>(context);
 }
 
 // LocalAllocOp
@@ -3055,7 +3078,8 @@ int32_t LocalAllocOp::getAlignmentOrDefault() {
 //===----------------------------------------------------------------------===//
 
 // Return N-D delinearized indices from a linear index.
-static SmallVector<int64_t> delinearize(int64_t idx, ArrayRef<int64_t> shape) {
+static SmallVector<int64_t> delinearizeIndex(int64_t idx,
+                                             ArrayRef<int64_t> shape) {
   SmallVector<int64_t> ret(shape.size());
   for (int i = shape.size() - 1; i >= 0; i--) {
     ret[i] = idx % shape[i];
@@ -3152,7 +3176,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
   int rank = tensorType.getRank();
   bool newLine = true;
   for (int i = 0; i < tensorSize; i++) {
-    auto indices = delinearize(i, tensorType.getShape());
+    auto indices = delinearizeIndex(i, tensorType.getShape());
     int numOpenBracket = 0;
     for (int j = rank - 1; j >= 0; j--) {
       if (indices[j] % tensorType.getDimSize(j) != 0)
@@ -3167,7 +3191,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
     }
 
     layoutStr += elementMapping[i];
-    auto nextIndices = delinearize(i + 1, tensorType.getShape());
+    auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
     for (int j = rank - 1; j >= 0; j--) {
       if (nextIndices[j] % tensorType.getDimSize(j) != 0)
         break;

test/Triton/invalid.mlir

Lines changed: 0 additions & 15 deletions
@@ -202,21 +202,6 @@ tt.func public @fn(%arg0: tensor<2xf32>) {
 
 // -----
 
-// Bad order; should start with 2.
-#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [1,2,0]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}>
-
-module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
-  tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
-    // expected-error @+2 {{last dimension}}
-    // expected-error @+1 {{op failed to infer returned types}}
-    %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
-    tt.return
-  }
-} // end module
-
-// -----
-
 #blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}>
 // Bad order, should be [1,0].
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}>

test/TritonGPU/combine.mlir

Lines changed: 18 additions & 0 deletions
@@ -2589,3 +2589,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
   // CHECK: tt.return %[[W]]#0, %[[W]]#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @split_propagation
+  // CHECK-SAME: (%[[ARG:.+]]: tensor<128x64x2xf32
+  // CHECK: %[[S:.+]], %{{.+}} = tt.split %[[ARG]]
+  // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[S]]
+  // CHECK: tt.return %[[C]]
+  tt.func public @split_propagation(%arg0: tensor<128x64x2xf32, #blocked>) -> tensor<128x64xf32, #blocked1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2>
+    %outLHS, %outRHS = tt.split %0 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked1>
+    tt.return %outLHS : tensor<128x64xf32, #blocked1>
+  }
+}
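
The new test exercises the extra propagation end to end: since tt.split no longer requires the split dimension to be most-minor, the convert_layout on the 3-D tensor can be pushed past the split (the CHECK lines expect tt.split to run directly on the argument), leaving only a convert on the smaller 2-D result.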
