 func.func @mlp(%in: tensor<128x512xbf16>,
                %weight0: tensor<512x64xbf16>, %bias0: tensor<64xbf16>,
                %weight1: tensor<64x256xbf16>, %bias1: tensor<256xbf16>) -> tensor<128x256xbf16> {
+  // CHECK: [[MM1:%.+]] = onednn_graph.matmul
+  // CHECK: [[RL1:%.+]] = onednn_graph.relu [[MM1]]
+  // CHECK: [[MM2:%.+]] = onednn_graph.matmul
+  // CHECK: [[AD2:%.+]] = onednn_graph.add [[MM2]]
+  // CHECK: [[RL2:%.+]] = onednn_graph.relu [[AD2]]
+  // CHECK: return [[RL2]]
   %0 = onednn_graph.matmul %in, %weight0, %bias0
      : (tensor<128x512xbf16>, tensor<512x64xbf16>, tensor<64xbf16>) -> tensor<128x64xbf16>
   %1 = onednn_graph.relu %0 : (tensor<128x64xbf16>) -> tensor<128x64xbf16>
@@ -17,6 +23,10 @@ func.func @mlp(%in: tensor<128x512xbf16>,
 // CHECK-LABEL: @mlp_transpose_a
 func.func @mlp_transpose_a(%in: tensor<512x128xbf16>,
                            %weight0: tensor<512x256xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> {
+  // CHECK: [[MM1:%.+]] = onednn_graph.matmul
+  // CHECK: {transpose_a = true}
+  // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]]
+  // CHECK-NEXT: return [[RL1]]
   %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true}
      : (tensor<512x128xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16>
   %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16>
@@ -26,6 +36,10 @@ func.func @mlp_transpose_a(%in: tensor<512x128xbf16>,
 // CHECK-LABEL: @mlp_transpose_b
 func.func @mlp_transpose_b(%in: tensor<128x512xbf16>,
                            %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> {
+  // CHECK: [[MM1:%.+]] = onednn_graph.matmul
+  // CHECK: {transpose_b = true}
+  // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]]
+  // CHECK-NEXT: return [[RL1]]
   %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_b = true}
      : (tensor<128x512xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16>
   %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16>
@@ -35,6 +49,10 @@ func.func @mlp_transpose_b(%in: tensor<128x512xbf16>,
 // CHECK-LABEL: @mlp_transpose_a_b
 func.func @mlp_transpose_a_b(%in: tensor<512x128xbf16>,
                              %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> {
+  // CHECK: [[MM1:%.+]] = onednn_graph.matmul
+  // CHECK: {transpose_a = true, transpose_b = true}
+  // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]]
+  // CHECK-NEXT: return [[RL1]]
   %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true, transpose_b = true}
      : (tensor<512x128xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16>
   %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16>