Skip to content

Commit f9070b2

Browse files
committed
[mlir][vector] Enable CastAwayElementwiseLeadingOneDim for scalable vec
This patch effectively enables the CastAwayElementwiseLeadingOneDim rewrite pattern for scalable vectors. To this end, `ExtractOp::inferReturnTypes` is updated so that scalable dimensions are correctly recognised. The change to ExtractOp will likely also make other conversion patterns valid for scalable vectors, but this patch focuses on just one case. Other conversion patterns will be enabled in forthcoming patches. Depends on D157993 Differential Revision: https://reviews.llvm.org/D158335
1 parent 841c4dc commit f9070b2

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,8 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
11511151
auto n =
11521152
std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
11531153
inferredReturnTypes.push_back(VectorType::get(
1154-
vectorType.getShape().drop_front(n), vectorType.getElementType()));
1154+
vectorType.getShape().drop_front(n), vectorType.getElementType(),
1155+
vectorType.getScalableDims().drop_front(n)));
11551156
}
11561157
return success();
11571158
}

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ struct CastAwayContractionLeadingOneDim
417417
}
418418
};
419419

420+
/// Looks at elementwise operations on vectors with at least one leading
421+
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
422+
/// and casts away the leading one dimensions (_plural_) and then broadcasts
423+
/// the results.
424+
///
425+
/// Example before:
426+
/// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
427+
/// Example after:
428+
/// %2 = arith.mulf %0, %1 : vector<4x1xf32>
429+
/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
430+
///
431+
/// Does support scalable vectors.
420432
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
421433
public:
422434
CastAwayElementwiseLeadingOneDim(MLIRContext *context,

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,30 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
276276
return %0: vector<1x1x4xf32>
277277
}
278278

279+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable(
280+
// CHECK-SAME: %[[S:.*]]: f32,
281+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
282+
func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
283+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x[4]xf32>
284+
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
285+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
286+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
287+
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
288+
return %0: vector<1x1x[4]xf32>
289+
}
290+
291+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
292+
// CHECK-SAME: %[[S:.*]]: f32,
293+
// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
294+
func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
295+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<1x[1]x4xf32>
296+
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
297+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
298+
// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32>
299+
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
300+
return %0: vector<1x[1]x4xf32>
301+
}
302+
279303
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
280304
// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
281305
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
@@ -285,6 +309,16 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
285309
return %0: vector<1x1x4xf32>
286310
}
287311

312+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable(
313+
// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>,
314+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
315+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
316+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
317+
func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
318+
%0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
319+
return %0: vector<1x1x[4]xf32>
320+
}
321+
288322
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
289323
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
290324
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -295,6 +329,17 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
295329
return %0: vector<1x1x4xf32>
296330
}
297331

332+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable(
333+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
334+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
335+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
336+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
337+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
338+
func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
339+
%0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
340+
return %0: vector<1x1x[4]xf32>
341+
}
342+
298343
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
299344
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
300345
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -307,6 +352,19 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
307352
return %0: vector<1x2x1x4xf32>
308353
}
309354

355+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
356+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
357+
// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
358+
// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
359+
// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<1x2x1x[4]xf32>
360+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
361+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
362+
// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32>
363+
func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
364+
%0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
365+
return %0: vector<1x2x1x[4]xf32>
366+
}
367+
310368
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
311369
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
312370
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -317,6 +375,17 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
317375
return %0: vector<8x1x4xf32>
318376
}
319377

378+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
379+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
380+
// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
381+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
382+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
383+
// CHECK: return %[[INSERT]] : vector<8x1x[4]xf32>
384+
func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
385+
%0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
386+
return %0: vector<8x1x[4]xf32>
387+
}
388+
320389
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
321390
// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
322391
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1>
@@ -328,3 +397,16 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
328397
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
329398
return %0: vector<1x1x8x1x8xi1>
330399
}
400+
401+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
402+
// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>,
403+
// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
404+
// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[8]xi1>
405+
// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x[8]xi1>
406+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
407+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
408+
// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
409+
func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
410+
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
411+
return %0: vector<1x1x8x1x[8]xi1>
412+
}

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) ->
3434
return %1 : vector<2x[4]x1xf32>
3535
}
3636

37+
// CHECK-LABEL: func.func @cast_away_leading_one_dim(
38+
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
39+
// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
40+
func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
41+
%1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
42+
return %1: vector<1x4x1xf32>
43+
}
44+
45+
// CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable(
46+
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
47+
// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
48+
func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
49+
%1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
50+
return %1: vector<1x[4]x1xf32>
51+
}
52+
3753
// CHECK-LABEL: func @add4x4
3854
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
3955
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

0 commit comments

Comments
 (0)