Skip to content

Commit 9f1da90

Browse files
authored
[mlir][SPIRV] Do not rewrite CompositeInsert for coopmatrix (#137837)
When rewriting multiple CompositeInserts to CompositeConstruct, we need to know the number of elements of the result type. However, we cannot query the number of elements for cooperative matrix types.
1 parent bc546ca commit 9f1da90

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() {
8484
LogicalResult RewriteInsertsPass::collectInsertionChain(
8585
spirv::CompositeInsertOp op,
8686
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
87+
if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
88+
return failure();
89+
8790
auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
8891
// TODO: handle nested composite object.
8992
if (indicesArrayAttr.size() == 1) {

mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,15 @@ spirv.module Logical GLSL450 {
2929
spirv.ReturnValue %3 : vector<3xf32>
3030
}
3131
}
32+
33+
// -----
34+
35+
spirv.module Logical GLSL450 {
36+
spirv.func @insertCoopMatrix(%value : f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> "None" {
37+
%0 = spirv.Undef : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
38+
// CHECK: spirv.CompositeInsert {{%.*}}, {{%.*}} : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
39+
%1 = spirv.CompositeInsert %value, %0[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
40+
41+
spirv.ReturnValue %1 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
42+
}
43+
}

0 commit comments

Comments
 (0)