Commit b824db2

Huang, Haixin authored and zhczhong committed
add test & bugfix for new pass
1 parent f1c40b8 commit b824db2

3 files changed: +288 -1 lines changed

include/gc/Transforms/Microkernel/MicrokernelPasses.td

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::Module
                              "microkernel::MicrokernelDialect"];
 }
 
-def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::func::FuncOp"> {
+def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> {
   let summary = "Find and merge identical microkernel context operations in branches into one";
   let description = [{
     Find and merge identical microkernel context operations in branches into one.
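The pass anchor changes from ::mlir::func::FuncOp to ::mlir::ModuleOp, which lets the analysis walk module-level constructs such as the global constructor inspected by traceDispatchInGlobalCtor in the next file. As a rough illustration of the rewrite (distilled from the CHECK patterns in the new test, not verbatim pass output; argument lists are elided), identical tile-context calls in the two branches of an scf.if are merged into a single pair around the branch:

  // Before (schematic): both branches configure and release the same tile context.
  scf.if %cmp {
    func.call @dnnl_brgemm_tileconfig(...)            // identical config
    // ... brgemm execution ...
    func.call @dnnl_brgemm_tilerelease() : () -> ()
  } else {
    func.call @dnnl_brgemm_tileconfig(...)            // identical config
    // ... brgemm execution ...
    func.call @dnnl_brgemm_tilerelease() : () -> ()
  }

  // After: one merged context surrounds the branch.
  func.call @dnnl_brgemm_tileconfig(...)
  scf.if %cmp {
    // ... brgemm execution ...
  } else {
    // ... brgemm execution ...
  }
  func.call @dnnl_brgemm_tilerelease() : () -> ()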

lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp

Lines changed: 9 additions & 0 deletions
@@ -38,6 +38,8 @@ class BrgemmDispatchAnalysis {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BrgemmDispatchAnalysis)
   explicit BrgemmDispatchAnalysis(Operation *);
   void setKernelDispatch(Operation *tilecfg, Operation *dispatch) {
+    LLVM_DEBUG(llvm::dbgs() << "* setKernelDispatch: " << tilecfg << "; "
+                            << dispatch << "\n");
     brgemmDispatches[tilecfg] = dispatch;
   };
   Operation *getKernelDispatch(Operation *tilecfg) const {
@@ -50,6 +52,8 @@ class BrgemmDispatchAnalysis {
 };
 
 BrgemmDispatchAnalysis::BrgemmDispatchAnalysis(Operation *root) {
+  LLVM_DEBUG(llvm::dbgs() << "* construct BrgemmDispatchAnalysis: " << *root
+                          << "\n");
   ModuleOp module = dyn_cast_or_null<ModuleOp>(root);
   if (!module)
     return;
@@ -108,6 +112,8 @@ BrgemmDispatchAnalysis::traceDispatchInGlobalCtor(ModuleOp module,
   for (auto &opRef : body.getOps()) {
     auto *op = &opRef;
     auto tryCallOp = dyn_cast_or_null<func::CallOp>(op);
+    if (!tryCallOp)
+      continue;
     auto callee = tryCallOp.getCalleeAttr().getAttr();
     if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME))
       return op;
@@ -122,8 +128,11 @@ extractTileOpsFromRegion(Region &region) {
   std::pair<Operation *, Operation *> ret{nullptr, nullptr};
 
   for (auto &opRef : region.getOps()) {
+    LLVM_DEBUG(llvm::dbgs() << ">>> " << opRef << "\n");
     auto *op = &opRef;
     auto tryCallOp = dyn_cast_or_null<func::CallOp>(op);
+    if (!tryCallOp)
+      continue;
     auto callee = tryCallOp.getCalleeAttr().getAttr();
     if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME))
       ret.first = op;
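The two added `if (!tryCallOp) continue;` guards are the bugfix named in the commit message: both loops visit every operation in a block, and dyn_cast_or_null<func::CallOp> returns a null handle for anything that is not a func.call, so the following getCalleeAttr() call would dereference null as soon as the region contains any other operation. A minimal sketch of the guard pattern in isolation (illustrative only; findCallTo is a hypothetical helper, not part of the pass):

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/Region.h"

  // Walk a region and inspect only func.call ops; other ops are skipped
  // instead of crashing on a null cast result.
  static mlir::Operation *findCallTo(mlir::Region &region, mlir::StringAttr target) {
    for (mlir::Operation &opRef : region.getOps()) {
      auto callOp = llvm::dyn_cast_or_null<mlir::func::CallOp>(&opRef);
      if (!callOp)
        continue; // non-call op (e.g. arith.constant): nothing to inspect
      if (callOp.getCalleeAttr().getAttr() == target)
        return &opRef;
    }
    return nullptr;
  }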
Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -merge-branch-microkernel-context -split-input-file | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @simple_brgemm() {
    %c0_i64 = arith.constant 0 : i64
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c4_index = arith.constant 4 : index
    %c8_index = arith.constant 8 : index
    %c16_i64 = arith.constant 16 : i64
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
        linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
        %cmp = arith.cmpi eq, %arg0, %c0_index : index
        scf.if %cmp {
          %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%0) : (i64) -> ()
          microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%0) : (i64) -> ()
        } else {
          %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%1) : (i64) -> ()
          microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%1) : (i64) -> ()
        }
        memref.dealloc %alloc_3 : memref<32x32xf32>
      }
    }
    return
  }
}

// CHECK-LABEL: simple_brgemm

// CHECK: scf.for %arg0 = %c0 to %c4 step %c1
// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1

// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK-NEXT: scf.if
// CHECK: } else {
// CHECK: }
// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> ()

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @simple_brgemm() {
    %c0_i64 = arith.constant 0 : i64
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c4_index = arith.constant 4 : index
    %c8_index = arith.constant 8 : index
    %c16_i64 = arith.constant 16 : i64
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
        linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
        %cmp = arith.cmpi eq, %arg0, %c0_index : index
        scf.if %cmp {
          %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%0) : (i64) -> ()
          microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%0) : (i64) -> ()
        }
        memref.dealloc %alloc_3 : memref<32x32xf32>
      }
    }
    return
  }
}

// CHECK-LABEL: simple_brgemm

// CHECK: scf.for %arg0 = %c0 to %c4 step %c1
// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1

// CHECK: scf.if
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: }

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @simple_brgemm() {
    %c0_i64 = arith.constant 0 : i64
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c4_index = arith.constant 4 : index
    %c8_index = arith.constant 8 : index
    %c16_i64 = arith.constant 16 : i64
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
        linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
        %cmp = arith.cmpi eq, %arg0, %c0_index : index
        scf.if %cmp {
          %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%0) : (i64) -> ()
          microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%0) : (i64) -> ()
        } else {
          %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%1) : (i64) -> ()
          microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%1) : (i64) -> ()
        }
        memref.dealloc %alloc_3 : memref<32x32xf32>
      }
    }
    return
  }
}

// CHECK-LABEL: simple_brgemm

// CHECK: scf.for %arg0 = %c0 to %c4 step %c1
// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1

// CHECK: scf.if
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: } else {
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: }

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @simple_brgemm() {
    %c0_i64 = arith.constant 0 : i64
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c4_index = arith.constant 4 : index
    %c8_index = arith.constant 8 : index
    %c16_i64 = arith.constant 16 : i64
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
        linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
        scf.index_switch %arg0
        case 0 {
          %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%0) : (i64) -> ()
          microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%0) : (i64) -> ()
          scf.yield
        }
        case 1 {
          %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%1) : (i64) -> ()
          microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%1) : (i64) -> ()
          scf.yield
        }
        default {
          %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%2) : (i64) -> ()
          microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%2) : (i64) -> ()
          scf.yield
        }
        memref.dealloc %alloc_3 : memref<32x32xf32>
      }
    }
    return
  }
}

// CHECK-LABEL: simple_brgemm

// CHECK: scf.for %arg0 = %c0 to %c4 step %c1
// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1

// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK-NEXT: scf.index_switch
// CHECK: case 0 {
// CHECK: case 1 {
// CHECK: default {
// CHECK: }
// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> ()

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @simple_brgemm() {
    %c0_i64 = arith.constant 0 : i64
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c4_index = arith.constant 4 : index
    %c8_index = arith.constant 8 : index
    %c16_i64 = arith.constant 16 : i64
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
        linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
        scf.index_switch %arg0
        case 0 {
          %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%0) : (i64) -> ()
          microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%0) : (i64) -> ()
          scf.yield
        }
        case 1 {
          %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%1) : (i64) -> ()
          microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%1) : (i64) -> ()
          scf.yield
        }
        default {
          %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16)
          microkernel.brgemm.prologue(%2) : (i64) -> ()
          microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
          microkernel.brgemm.epilogue(%2) : (i64) -> ()
          scf.yield
        }
        memref.dealloc %alloc_3 : memref<32x32xf32>
      }
    }
    return
  }
}

// CHECK-LABEL: simple_brgemm

// CHECK: scf.for %arg0 = %c0 to %c4 step %c1
// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1

// CHECK: scf.index_switch
// CHECK: case 0 {
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: case 1 {
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: default {
// CHECK: func.call @dnnl_brgemm_tileconfig
// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> ()
// CHECK: }
