Skip to content

Commit 53d709c

Browse files
committed
Further relax dim check
1 parent 525d2f1 commit 53d709c

File tree

2 files changed

+5
-44
lines changed

2 files changed

+5
-44
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,9 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
341341
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
342342

343343
// TODO: Update shape validation to be target aware.
344-
auto rhsShape = rhs.getType().getShape();
345344
auto accShape = accType.getShape();
346-
int64_t dimM = accShape[0];
347345
int64_t dimN = accShape[1];
348-
int64_t dimK = rhsShape[0];
349-
if (dimM != 8 || dimN != 16 || dimK % 8 != 0)
346+
if (dimN != 8 && dimN != 16)
350347
return rewriter.notifyMatchFailure(contractOp,
351348
"Invalid operand dimensions");
352349

mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -144,51 +144,15 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
144144
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
145145
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
146146
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
147-
func.func @negative_m_dim_size(%lhs: vector<16x16xf16>, %rhs: vector<16x16xf16>,
148-
%acc: vector<16x16xf32>) -> vector<16x16xf32> {
147+
func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
148+
%acc: vector<8x32xf32>) -> vector<8x32xf32> {
149149
%3 = vector.contract
150150
{indexing_maps = [#map, #map1, #map2],
151151
iterator_types = ["parallel", "parallel", "reduction"],
152152
kind = #vector.kind<add>} %lhs, %rhs, %acc
153-
: vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
154-
return %3 : vector<16x16xf32>
155-
}
156-
157-
// CHECK-LABEL: @negative_m_dim_size(
158-
// CHECK: vector.contract
159-
160-
// -----
161-
162-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
163-
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
164-
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
165-
func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x8xf16>,
166-
%acc: vector<8x8xf32>) -> vector<8x8xf32> {
167-
%3 = vector.contract
168-
{indexing_maps = [#map, #map1, #map2],
169-
iterator_types = ["parallel", "parallel", "reduction"],
170-
kind = #vector.kind<add>} %lhs, %rhs, %acc
171-
: vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf32>
172-
return %3 : vector<8x8xf32>
153+
: vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
154+
return %3 : vector<8x32xf32>
173155
}
174156

175157
// CHECK-LABEL: @negative_n_dim_size(
176158
// CHECK: vector.contract
177-
178-
// -----
179-
180-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
181-
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
182-
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
183-
func.func @negative_k_dim_size(%lhs: vector<8x12xf16>, %rhs: vector<12x16xf16>,
184-
%acc: vector<8x16xf32>) -> vector<8x16xf32> {
185-
%3 = vector.contract
186-
{indexing_maps = [#map, #map1, #map2],
187-
iterator_types = ["parallel", "parallel", "reduction"],
188-
kind = #vector.kind<add>} %lhs, %rhs, %acc
189-
: vector<8x12xf16>, vector<12x16xf16> into vector<8x16xf32>
190-
return %3 : vector<8x16xf32>
191-
}
192-
193-
// CHECK-LABEL: @negative_k_dim_size(
194-
// CHECK: vector.contract

0 commit comments

Comments
 (0)