@@ -29,7 +29,7 @@ func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
29
29
return
30
30
}
31
31
32
- func.func @testTransposedReadWithMask () {
32
+ func.func @testTransposedReadWithMask (%maskRows: index , %maskCols: index ) {
33
33
%in = memref.alloca () : memref <4 x16 xf32 >
34
34
%out = memref.alloca () : memref <16 x4 xf32 >
35
35
@@ -38,9 +38,7 @@ func.func @testTransposedReadWithMask() {
38
38
39
39
func.call @fill2DMemrefRows (%inDyn ) : (memref <?x?xf32 >) -> ()
40
40
41
- /// A mask so we only read the first 2x15 portion of %in.
42
- %maskRows = arith.constant 2 : index
43
- %maskCols = arith.constant 15 : index
41
+ /// A mask so we only read the first %maskRows x %maskCols portion of %in.
44
42
%mask = vector.create_mask %maskRows , %maskCols : vector <[4 ]x[16 ]xi1 >
45
43
%pad = arith.constant 0.0 : f32
46
44
%c0 = arith.constant 0 : index
@@ -59,35 +57,31 @@ func.func @testTransposedReadWithMask() {
59
57
call @printMemrefF32 (%inUnranked ) : (memref <*xf32 >) -> ()
60
58
61
59
/// Print the result memref.
62
- vector.print str " ( Masked 15x2) transposed result:"
60
+ vector.print str " Masked transposed result:"
63
61
%outUnranked = memref.cast %outDyn : memref <?x?xf32 > to memref <*xf32 >
64
62
call @printMemrefF32 (%outUnranked ) : (memref <*xf32 >) -> ()
65
63
66
64
return
67
65
}
68
66
69
- func.func @testTransposedWriteWithMask () {
67
+ func.func @testTransposedWriteWithMask (%maskRows: index , %maskCols: index ) {
70
68
%in = memref.alloca () : memref <16 x4 xf32 >
71
69
%out = memref.alloca () : memref <4 x16 xf32 >
72
70
73
- %fill = arith.constant -1 .0 : f32
74
- linalg.fill ins (%fill : f32 ) outs (%out : memref <4 x16 xf32 >)
71
+ %c0_f32 = arith.constant 0 .0 : f32
72
+ linalg.fill ins (%c0_f32 : f32 ) outs (%out : memref <4 x16 xf32 >)
75
73
76
74
%inDyn = memref.cast %in : memref <16 x4 xf32 > to memref <?x?xf32 >
77
75
%outDyn = memref.cast %out : memref <4 x16 xf32 > to memref <?x?xf32 >
78
76
79
77
func.call @fill2DMemrefRows (%inDyn ) : (memref <?x?xf32 >) -> ()
80
78
81
- %pad = arith.constant 0.0 : f32
82
- %c0 = arith.constant 0 : index
83
-
84
79
/// A regular read.
85
- %read = vector.transfer_read %inDyn [%c0 , %c0 ], %pad {in_bounds = [true , true ]}
80
+ %c0 = arith.constant 0 : index
81
+ %read = vector.transfer_read %inDyn [%c0 , %c0 ], %c0_f32 {in_bounds = [true , true ]}
86
82
: memref <?x?xf32 >, vector <[16 ]x[4 ]xf32 >
87
83
88
- /// A mask so we only write the first 3x8 portion of transpose(%in).
89
- %maskRows = arith.constant 3 : index
90
- %maskCols = arith.constant 8 : index
84
+ /// A mask so we only write the first %maskRows x %maskCols portion of transpose(%in).
91
85
%mask = vector.create_mask %maskRows , %maskCols : vector <[4 ]x[16 ]xi1 >
92
86
93
87
/// Write out the data with a transpose. Here (like the read test) the mask
@@ -101,7 +95,7 @@ func.func @testTransposedWriteWithMask() {
101
95
call @printMemrefF32 (%inUnranked ) : (memref <*xf32 >) -> ()
102
96
103
97
/// Print the result memref.
104
- vector.print str " ( Masked 3x8) transposed result:"
98
+ vector.print str " Masked transposed result:"
105
99
%outUnranked = memref.cast %outDyn : memref <?x?xf32 > to memref <*xf32 >
106
100
call @printMemrefF32 (%outUnranked ) : (memref <*xf32 >) -> ()
107
101
@@ -120,7 +114,7 @@ func.func @main() {
120
114
// CHECK-NEXT: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
121
115
// CHECK-NEXT: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
122
116
//
123
- // CHECK: (Masked 15x2) transposed result:
117
+ // CHECK: Masked transposed result:
124
118
// CHECK: [1, 2, 0, 0]
125
119
// CHECK-NEXT: [1, 2, 0, 0]
126
120
// CHECK-NEXT: [1, 2, 0, 0]
@@ -137,7 +131,9 @@ func.func @main() {
137
131
// CHECK-NEXT: [1, 2, 0, 0]
138
132
// CHECK-NEXT: [1, 2, 0, 0]
139
133
// CHECK-NEXT: [0, 0, 0, 0]
140
- func.call @testTransposedReadWithMask () : () -> ()
134
+ %readMaskRows = arith.constant 2 : index
135
+ %readMaskCols = arith.constant 15 : index
136
+ func.call @testTransposedReadWithMask (%readMaskRows , %readMaskCols ) : (index , index ) -> ()
141
137
142
138
// CHECK: Input memref:
143
139
// CHECK: [1, 1, 1, 1]
@@ -157,12 +153,14 @@ func.func @main() {
157
153
// CHECK-NEXT: [15, 15, 15, 15]
158
154
// CHECK-NEXT: [16, 16, 16, 16]
159
155
//
160
- // CHECK: (Masked 3x8) transposed result:
161
- // CHECK: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
162
- // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
163
- // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
164
- // CHECK-NEXT: [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
165
- func.call @testTransposedWriteWithMask () : () -> ()
156
+ // CHECK: Masked transposed result:
157
+ // CHECK: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
158
+ // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
159
+ // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
160
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
161
+ %writeMaskRows = arith.constant 3 : index
162
+ %writeMaskCols = arith.constant 8 : index
163
+ func.call @testTransposedWriteWithMask (%writeMaskRows , %writeMaskCols ) : (index , index ) -> ()
166
164
167
165
return
168
166
}
0 commit comments