Commit d6d6e7e

uplift pack over broadcast

1 parent b26b37e

3 files changed: +148 -52 lines

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 84 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"

 #include "gc/Dialect/Linalgx/LinalgxDialect.h"
 #include "gc/Dialect/Linalgx/LinalgxOps.h"
@@ -495,6 +496,83 @@ struct PackVNNI<linalg::GenericOp>
   }
 };

+/*
+Match broadcast + pack patterns and uplift the pack above the broadcast.
+*/
+struct UpliftPackOverBroadcast : public OpRewritePattern<tensor::PackOp> {
+  UpliftPackOverBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PackOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(tensor::PackOp pack,
+                                PatternRewriter &rewriter) const override {
+    auto broadcastOp = pack.getSource().getDefiningOp<linalg::BroadcastOp>();
+    if (!broadcastOp || !broadcastOp.getResult()[0].hasOneUse()) {
+      return failure();
+    }
+    SmallVector<int64_t> innerTileSizes = pack.getStaticTiles();
+    SmallVector<int64_t> innerDimsPos(pack.getInnerDimsPos());
+    SmallVector<int64_t> outerDimsPerm(pack.getOuterDimsPerm());
+    int64_t rank =
+        cast<ShapedType>(pack.getSource().getType()).getShape().size();
+    if (outerDimsPerm.empty()) {
+      outerDimsPerm.resize(rank);
+      std::iota(outerDimsPerm.begin(), outerDimsPerm.end(), 0);
+    }
+    ArrayRef<int64_t> broadcastAxis = broadcastOp.getDimensions();
+    SmallVector<int64_t> newInnerDimsPos, newOuterDimsPerm, packedBroadcastAxis;
+    SmallVector<OpFoldResult> newInnerTileSizes;
+    llvm::SmallDenseMap<int64_t, int64_t> axisMapping;
+    int64_t axisCounter = 0;
+    for (int64_t axis = 0; axis < rank; ++axis) {
+      if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) ==
+          broadcastAxis.end()) {
+        // if the axis is not broadcast, keep it (renumbered densely)
+        axisMapping[axis] = axisCounter++;
+      }
+    }
+    // update broadcast dims
+    for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) {
+      if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) !=
+          broadcastAxis.end()) {
+        packedBroadcastAxis.push_back(index);
+      }
+    }
+    for (auto [index, axis] : llvm::enumerate(innerDimsPos)) {
+      if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) !=
+          broadcastAxis.end()) {
+        packedBroadcastAxis.push_back(index + rank);
+      }
+    }
+    // update packing axis
+    for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) {
+      if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) ==
+          broadcastAxis.end()) {
+        newOuterDimsPerm.push_back(axisMapping[axis]);
+      }
+    }
+    for (auto [index, axis] : llvm::enumerate(innerDimsPos)) {
+      if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) ==
+          broadcastAxis.end()) {
+        newInnerDimsPos.push_back(axisMapping[axis]);
+        newInnerTileSizes.push_back(
+            rewriter.getIndexAttr(innerTileSizes[index]));
+      }
+    }
+    // replace ops
+    auto loc = broadcastOp.getLoc();
+    auto dest = tensor::PackOp::createDestinationTensor(
+        rewriter, loc, broadcastOp.getDpsInputs()[0], newInnerTileSizes,
+        newInnerDimsPos, newOuterDimsPerm);
+    Value packedSource = rewriter.create<tensor::PackOp>(
+        loc, broadcastOp.getDpsInputs()[0], dest, newInnerDimsPos,
+        newInnerTileSizes,
+        /*padding=*/std::nullopt, newOuterDimsPerm);
+    auto newBroadcastOp = rewriter.create<linalg::BroadcastOp>(
+        loc, packedSource, pack.getDest(), packedBroadcastAxis);
+    rewriter.replaceOp(pack, newBroadcastOp.getResults());
+    return success();
+  }
+};
+
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::Operation *graph = getOperation();
@@ -541,6 +619,12 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
   };
   if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn)))
     return signalPassFailure();
+
+  // stage4: uplift pack over broadcast
+  RewritePatternSet upliftPatterns(&getContext());
+  upliftPatterns.add<UpliftPackOverBroadcast>(ctx);
+  if (failed(applyPatternsAndFoldGreedily(graph, std::move(upliftPatterns))))
+    return signalPassFailure();
 }

 } // namespace gc
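
The effect of the new pattern, as a minimal before/after sketch in MLIR (hypothetical IR, assuming a rank-2 pack with 32x32 tiles over a broadcast along dimension 0; the %-names and destination tensors are illustrative, not pass output):

 // Before: tensor.pack consumes the broadcast result, so the full
 // 128x32 tensor is materialized and then tiled.
 %b = linalg.broadcast ins(%src : tensor<32xf32>)
        outs(%e0 : tensor<128x32xf32>) dimensions = [0]
 %p = tensor.pack %b inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %e1 : tensor<128x32xf32> -> tensor<4x1x32x32xf32>

 // After: the pack is uplifted above the broadcast. Only the
 // non-broadcast axis survives in the packed source (axisMapping
 // renumbers it to 0), and packedBroadcastAxis = [0, 2] records the
 // outer and inner positions of the broadcast axis in the packed result.
 %p2 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [32]
         into %e2 : tensor<32xf32> -> tensor<1x32xf32>
 %b2 = linalg.broadcast ins(%p2 : tensor<1x32xf32>)
         outs(%e1 : tensor<4x1x32x32xf32>) dimensions = [0, 2]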
Lines changed: 6 additions & 52 deletions
@@ -1,58 +1,12 @@
 // RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s

-// CHECK-LABEL: @single_matmul_f32
-func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> {
+// CHECK-LABEL: @matmul_add
+func.func @matmul_add(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<32xf32>) -> tensor<128x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<128x32xf32>
   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32>
-  return %2 : tensor<128x32xf32>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%1 : tensor<128x32xf32>) -> tensor<128x32xf32>
+  %3 = linalg.broadcast ins(%arg2 : tensor<32xf32>) outs(%0 : tensor<128x32xf32>) dimensions = [0]
+  %4 = linalg.add ins(%2, %3 : tensor<128x32xf32>, tensor<128x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32>
+  return %4 : tensor<128x32xf32>
 }
-// CHECK-COUNT-3: tensor.pack
-// CHECK-COUNT-1: linalg.generic
-// CHECK-COUNT-1: tensor.unpack
-
-// CHECK-LABEL: @single_matmul_bf16
-func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<128x32xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
-  return %2 : tensor<128x32xbf16>
-}
-// CHECK-COUNT-4: tensor.pack
-// CHECK-COUNT-1: linalgx.mm4d_vnni
-// CHECK-COUNT-1: tensor.unpack
-
-// CHECK-LABEL: @single_batch_matmul_bf16
-func.func @single_batch_matmul_bf16(%arg0: tensor<64x128x64xbf16>, %arg1: tensor<64x64x32xbf16>) -> tensor<64x128x32xbf16> {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<64x128x32xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16>
-  %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<64x128x64xbf16>, tensor<64x64x32xbf16>) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16>
-  return %2 : tensor<64x128x32xbf16>
-}
-// CHECK-COUNT-4: tensor.pack
-// CHECK-COUNT-1: linalg.generic
-// CHECK-COUNT-1: tensor.unpack
-
-func.func @pack_vnni_mmt4d(%arg0: tensor<4x2x32x32xbf16>, %arg1: tensor<1x2x32x32xbf16>) -> tensor<4x1x32x32xbf16> {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<4x1x32x32xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
-  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<4x2x32x32xbf16>, tensor<1x2x32x32xbf16>) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
-  return %2 : tensor<4x1x32x32xbf16>
-}
-// CHECK-COUNT-1: tensor.pack
-// CHECK-COUNT-1: linalgx.mm4d_vnni
-
-func.func @pack_vnni_batchmmt4d(%arg0: tensor<4x4x2x32x32xbf16>, %arg1: tensor<4x1x2x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<4x4x1x32x32xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
-  %2 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<4x4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
-  return %2 : tensor<4x4x1x32x32xbf16>
-}
-// CHECK-COUNT-1: tensor.pack
-// CHECK-COUNT-1: linalg.generic
-
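For the new @matmul_add test above, the bias broadcast now sits between a packed matmul and a packed add, so the uplift pattern should pack %arg2 before broadcasting instead of packing the broadcast result. A hedged sketch of the bias handling one might expect after the pass (assuming 32x32 tiles are chosen for this shape; the test as committed only checks the function label, so this is an expectation, not a CHECK line):

 %packed_bias = tensor.pack %arg2 inner_dims_pos = [0] inner_tiles = [32]
                  into %d0 : tensor<32xf32> -> tensor<1x32xf32>
 %bcast = linalg.broadcast ins(%packed_bias : tensor<1x32xf32>)
            outs(%d1 : tensor<4x1x32x32xf32>) dimensions = [0, 2]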
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s
+
+// CHECK-LABEL: @single_matmul_f32
+func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32>
+  return %2 : tensor<128x32xf32>
+}
+// CHECK-COUNT-3: tensor.pack
+// CHECK-COUNT-1: linalg.generic
+// CHECK-COUNT-1: tensor.unpack
+
+// CHECK-LABEL: @single_matmul_bf16
+func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> {
+  %cst = arith.constant 0.000000e+00 : bf16
+  %0 = tensor.empty() : tensor<128x32xbf16>
+  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
+  return %2 : tensor<128x32xbf16>
+}
+// CHECK-COUNT-4: tensor.pack
+// CHECK-COUNT-1: linalgx.mm4d_vnni
+// CHECK-COUNT-1: tensor.unpack
+
+// CHECK-LABEL: @single_batch_matmul_bf16
+func.func @single_batch_matmul_bf16(%arg0: tensor<64x128x64xbf16>, %arg1: tensor<64x64x32xbf16>) -> tensor<64x128x32xbf16> {
+  %cst = arith.constant 0.000000e+00 : bf16
+  %0 = tensor.empty() : tensor<64x128x32xbf16>
+  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16>
+  %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<64x128x64xbf16>, tensor<64x64x32xbf16>) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16>
+  return %2 : tensor<64x128x32xbf16>
+}
+// CHECK-COUNT-4: tensor.pack
+// CHECK-COUNT-1: linalg.generic
+// CHECK-COUNT-1: tensor.unpack
+
+func.func @pack_vnni_mmt4d(%arg0: tensor<4x2x32x32xbf16>, %arg1: tensor<1x2x32x32xbf16>) -> tensor<4x1x32x32xbf16> {
+  %cst = arith.constant 0.000000e+00 : bf16
+  %0 = tensor.empty() : tensor<4x1x32x32xbf16>
+  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
+  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<4x2x32x32xbf16>, tensor<1x2x32x32xbf16>) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
+  return %2 : tensor<4x1x32x32xbf16>
+}
+// CHECK-COUNT-1: tensor.pack
+// CHECK-COUNT-1: linalgx.mm4d_vnni
+
+func.func @pack_vnni_batchmmt4d(%arg0: tensor<4x4x2x32x32xbf16>, %arg1: tensor<4x1x2x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> {
+  %cst = arith.constant 0.000000e+00 : bf16
+  %0 = tensor.empty() : tensor<4x4x1x32x32xbf16>
+  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
+  %2 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<4x4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
+  return %2 : tensor<4x4x1x32x32xbf16>
+}
+// CHECK-COUNT-1: tensor.pack
+// CHECK-COUNT-1: linalg.generic
+
