Skip to content

Commit 34f0aea

Browse files
committed
[mlir][SPIRV] Add decorateType method for MatrixType
This PR adds a decorateType method for MatrixType, ensuring that `spirv.matrix` with offset in `spirv.struct` can be handled correctly. Signed-off-by: MingZhu Yan <[email protected]>
1 parent a527248 commit 34f0aea

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace spirv {
2424
class ArrayType;
2525
class RuntimeArrayType;
2626
class StructType;
27+
class MatrixType;
2728
} // namespace spirv
2829

2930
/// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
6768
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
6869
static Type decorateType(spirv::ArrayType arrayType, Size &size,
6970
Size &alignment);
71+
static Type decorateType(spirv::MatrixType matrixType, Size &size,
72+
Size &alignment);
7073
static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
7174
static spirv::StructType decorateType(spirv::StructType structType,
7275
Size &size, Size &alignment);

mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
9191
return decorateType(arrayType, size, alignment);
9292
if (auto vectorType = dyn_cast<VectorType>(type))
9393
return decorateType(vectorType, size, alignment);
94+
if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
95+
return decorateType(matrixType, size, alignment);
9496
if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
9597
size = std::numeric_limits<Size>().max();
9698
return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
138140
return spirv::ArrayType::get(memberType, numElements, elementSize);
139141
}
140142

143+
Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
144+
VulkanLayoutUtils::Size &size,
145+
VulkanLayoutUtils::Size &alignment) {
146+
const auto numColumns = matrixType.getNumColumns();
147+
const auto columnType = matrixType.getColumnType();
148+
const auto numElements = matrixType.getNumElements();
149+
Type elementType = matrixType.getElementType();
150+
Size elementSize = 0;
151+
Size elementAlignment = 1;
152+
153+
decorateType(elementType, elementSize, elementAlignment);
154+
// According to the Vulkan spec:
155+
// "A matrix type inherits scalar alignment from the equivalent array
156+
// declaration."
157+
size = elementSize * numElements;
158+
alignment = elementAlignment;
159+
return spirv::MatrixType::get(columnType, numColumns);
160+
}
161+
141162
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
142163
VulkanLayoutUtils::Size &alignment) {
143164
auto elementType = arrayType.getElementType();

mlir/test/Dialect/SPIRV/IR/types.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
497497

498498
// -----
499499

500+
// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
501+
func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
502+
503+
// -----
504+
500505
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
501506
func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
502507

0 commit comments

Comments
 (0)