 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
@@ -2661,22 +2662,21 @@ struct TritonGPUInferLayoutInterface
           loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
                "and CTAsPerCGA = 1 for the last dimension of the input");
     }
-    if (enc.getOrder().front() != enc.getOrder().size() - 1) {
-      return emitOptionalError(
-          loc, "SplitOp requires the last dimension to be most-minor in order");
-    }
     if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
       return emitOptionalError(
           loc,
           "SplitOp requires the last dimension to be most-minor in CTAOrder");
     }
-
+    SmallVector<unsigned> newOrder(enc.getOrder());
+    int splitDim = newOrder.size() - 1;
+    // Remove splitDim from order.
+    newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
+                   newOrder.end());
     dstEnc = BlockedEncodingAttr::get(
         enc.getContext(), //
         ArrayRef(enc.getSizePerThread()).drop_back(1),
         ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
-        ArrayRef(enc.getWarpsPerCTA()).drop_back(1),
-        ArrayRef(enc.getOrder()).drop_front(1),
+        ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
        CTALayoutAttr::get(enc.getContext(), //
                           ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
                           ArrayRef(enc.getCTASplitNum()).drop_back(1),
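
The change above stops assuming the split dimension is most-minor: instead of dropping the front of `order`, the last logical dimension is erased from the order wherever it sits. A minimal standalone sketch of that erase-remove step, using `std::vector` in place of `llvm::SmallVector` so it builds without LLVM headers; the order values are illustrative, not taken from the diff:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical 3-D order where the split dimension (rank - 1 == 2) is
      // not most-minor; the removed check would have rejected this layout.
      std::vector<unsigned> newOrder = {0, 2, 1};
      unsigned splitDim = newOrder.size() - 1; // == 2
      // Erase-remove idiom: drop splitDim wherever it occurs, keeping the
      // relative order of the remaining dimensions.
      newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
                     newOrder.end());
      for (unsigned d : newOrder)
        std::printf("%u ", d); // prints: 0 1
      return 0;
    }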
@@ -2764,6 +2764,28 @@ struct CanonicalizeConvertFromLocalStore
   }
 };
 
+struct CanonicalizeConvertFromSplit
+    : public mlir::OpRewritePattern<triton::SplitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::SplitOp op,
+                  PatternRewriter &rewriter) const override {
+    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
+    if (!convert)
+      return failure();
+    auto srcEncoding = convert.getSrc().getType().getEncoding();
+    // Multiple source layouts can give the same output layout; if the source
+    // layout of the convert gives the same destination layout, we can skip
+    // the convert.
+    auto dstEncoding = inferDstEncoding(op, srcEncoding);
+    if (dstEncoding != op.getOutLHS().getType().getEncoding())
+      return failure();
+    rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
+    return mlir::success();
+  }
+};
+
 struct CanonicalizeConvertFromConvert
     : public OpRewritePattern<ConvertLayoutOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -2896,6 +2918,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeConvertFromHistogram>(context);
   patterns.add<CanonicalizeConvertFromAlloc>(context);
   patterns.add<CanonicalizeConvertFromLocalStore>(context);
+  patterns.add<CanonicalizeConvertFromSplit>(context);
 }
 
 // LocalAllocOp
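
With the registration above, the new pattern participates in ConvertLayoutOp's canonicalization alongside the existing ones. For context only, a sketch of how these patterns could be collected and applied with MLIR's greedy rewrite driver; this driver code is not part of the diff, the helper name `runConvertLayoutCanonicalization` is hypothetical, and `module` is assumed to be an already-constructed ModuleOp:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    // Collect ConvertLayoutOp's canonicalization patterns (which now include
    // CanonicalizeConvertFromSplit) and apply them greedily to the module.
    static mlir::LogicalResult
    runConvertLayoutCanonicalization(mlir::ModuleOp module) {
      mlir::MLIRContext *context = module.getContext();
      mlir::RewritePatternSet patterns(context);
      mlir::triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(patterns,
                                                                      context);
      return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
    }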
@@ -3055,7 +3078,8 @@ int32_t LocalAllocOp::getAlignmentOrDefault() {
 //===----------------------------------------------------------------------===//
 
 // Return N-D delinearized indices from a linear index.
-static SmallVector<int64_t> delinearize(int64_t idx, ArrayRef<int64_t> shape) {
+static SmallVector<int64_t> delinearizeIndex(int64_t idx,
+                                             ArrayRef<int64_t> shape) {
   SmallVector<int64_t> ret(shape.size());
   for (int i = shape.size() - 1; i >= 0; i--) {
     ret[i] = idx % shape[i];
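
The rename from `delinearize` to `delinearizeIndex` only disambiguates the helper; the row-major delinearization itself is unchanged. A self-contained sketch of the same logic with a worked example; it uses `std::vector` instead of the LLVM types, and the `idx /= shape[i]` step after the modulo is assumed, since the hunk is truncated before the rest of the loop body:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Mirror of the helper's logic: peel off the fastest-varying (last)
    // dimension first.
    static std::vector<int64_t>
    delinearizeIndex(int64_t idx, const std::vector<int64_t> &shape) {
      std::vector<int64_t> ret(shape.size());
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--) {
        ret[i] = idx % shape[i];
        idx /= shape[i]; // assumed remainder of the loop body
      }
      return ret;
    }

    int main() {
      // For shape [2, 3, 4], linear index 17 maps to [1, 1, 1]:
      // 17 % 4 = 1, 17 / 4 = 4; 4 % 3 = 1, 4 / 3 = 1; 1 % 2 = 1.
      auto indices = delinearizeIndex(17, {2, 3, 4});
      for (int64_t v : indices)
        std::printf("%lld ", (long long)v); // prints: 1 1 1
      return 0;
    }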
@@ -3152,7 +3176,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
   int rank = tensorType.getRank();
   bool newLine = true;
   for (int i = 0; i < tensorSize; i++) {
-    auto indices = delinearize(i, tensorType.getShape());
+    auto indices = delinearizeIndex(i, tensorType.getShape());
     int numOpenBracket = 0;
     for (int j = rank - 1; j >= 0; j--) {
       if (indices[j] % tensorType.getDimSize(j) != 0)
@@ -3167,7 +3191,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
     }
 
     layoutStr += elementMapping[i];
-    auto nextIndices = delinearize(i + 1, tensorType.getShape());
+    auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
     for (int j = rank - 1; j >= 0; j--) {
       if (nextIndices[j] % tensorType.getDimSize(j) != 0)
         break;