Skip to content

Commit 2e37500

Browse files
author
git apple-llvm automerger
committed
Merge commit 'f9070b2dfbec' from llvm.org/main into next
2 parents 1811fa0 + f9070b2 commit 2e37500

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,8 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
11511151
auto n =
11521152
std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
11531153
inferredReturnTypes.push_back(VectorType::get(
1154-
vectorType.getShape().drop_front(n), vectorType.getElementType()));
1154+
vectorType.getShape().drop_front(n), vectorType.getElementType(),
1155+
vectorType.getScalableDims().drop_front(n)));
11551156
}
11561157
return success();
11571158
}

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ struct CastAwayContractionLeadingOneDim
417417
}
418418
};
419419

420+
/// Looks at elementwise operations on vectors with at least one leading
421+
/// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
422+
/// and cast aways the leading one dimensions (_plural_) and then broadcasts
423+
/// the results.
424+
///
425+
/// Example before:
426+
/// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
427+
/// Example after:
428+
/// %2 = arith.mulf %0, %1 : vector<4x1xf32>
429+
/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
430+
///
431+
/// Does support scalable vectors.
420432
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
421433
public:
422434
CastAwayElementwiseLeadingOneDim(MLIRContext *context,

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,30 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
276276
return %0: vector<1x1x4xf32>
277277
}
278278

279+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable(
280+
// CHECK-SAME: %[[S:.*]]: f32,
281+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
282+
func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
283+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x[4]xf32>
284+
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
285+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
286+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
287+
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
288+
return %0: vector<1x1x[4]xf32>
289+
}
290+
291+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
292+
// CHECK-SAME: %[[S:.*]]: f32,
293+
// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
294+
func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
295+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<1x[1]x4xf32>
296+
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
297+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
298+
// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32>
299+
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
300+
return %0: vector<1x[1]x4xf32>
301+
}
302+
279303
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
280304
// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
281305
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
@@ -285,6 +309,16 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
285309
return %0: vector<1x1x4xf32>
286310
}
287311

312+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable(
313+
// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>,
314+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
315+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
316+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
317+
func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
318+
%0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
319+
return %0: vector<1x1x[4]xf32>
320+
}
321+
288322
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
289323
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
290324
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -295,6 +329,17 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
295329
return %0: vector<1x1x4xf32>
296330
}
297331

332+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable(
333+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
334+
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
335+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
336+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
337+
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
338+
func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
339+
%0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
340+
return %0: vector<1x1x[4]xf32>
341+
}
342+
298343
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
299344
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
300345
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -307,6 +352,19 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
307352
return %0: vector<1x2x1x4xf32>
308353
}
309354

355+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
356+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
357+
// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
358+
// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
359+
// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<1x2x1x[4]xf32>
360+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
361+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
362+
// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32>
363+
func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
364+
%0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
365+
return %0: vector<1x2x1x[4]xf32>
366+
}
367+
310368
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
311369
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
312370
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -317,6 +375,17 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
317375
return %0: vector<8x1x4xf32>
318376
}
319377

378+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
379+
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
380+
// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
381+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
382+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
383+
// CHECK: return %[[INSERT]] : vector<8x1x[4]xf32>
384+
func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
385+
%0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
386+
return %0: vector<8x1x[4]xf32>
387+
}
388+
320389
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
321390
// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
322391
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1>
@@ -328,3 +397,16 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
328397
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
329398
return %0: vector<1x1x8x1x8xi1>
330399
}
400+
401+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
402+
// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>,
403+
// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
404+
// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[8]xi1>
405+
// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x[8]xi1>
406+
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
407+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
408+
// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
409+
func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
410+
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
411+
return %0: vector<1x1x8x1x[8]xi1>
412+
}

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) ->
3434
return %1 : vector<2x[4]x1xf32>
3535
}
3636

37+
// CHECK-LABEL: func.func @cast_away_leading_one_dim(
38+
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
39+
// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
40+
func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
41+
%1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
42+
return %1: vector<1x4x1xf32>
43+
}
44+
45+
// CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable(
46+
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
47+
// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
48+
func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
49+
%1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
50+
return %1: vector<1x[4]x1xf32>
51+
}
52+
3753
// CHECK-LABEL: func @add4x4
3854
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
3955
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

0 commit comments

Comments
 (0)