Skip to content

Commit a290c3a

Browse files
committed
[mlir][spirv] Improve stride support in array types
This commit added stride support in runtime array types. It also adjusted the assembly form for the stride from `[N]` to `stride=N`. This makes the IR more readable, especially for the cases where one mix array types and struct types. Differential Revision: https://reviews.llvm.org/D78034
1 parent 6cdcb9b commit a290c3a

26 files changed

+303
-224
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,15 @@ element-type ::= integer-type
287287
| vector-type
288288
| spirv-type
289289
290-
array-type ::= `!spv.array<` integer-literal `x` element-type `>`
290+
array-type ::= `!spv.array` `<` integer-literal `x` element-type
291+
(`,` `stride` `=` integer-literal)? `>`
291292
```
292293

293294
For example,
294295

295296
```mlir
296297
!spv.array<4 x i32>
298+
!spv.array<4 x i32, stride = 4>
297299
!spv.array<16 x vector<4 x f32>>
298300
```
299301

@@ -351,13 +353,14 @@ For example,
351353
This corresponds to SPIR-V [runtime array type][RuntimeArrayType]. Its syntax is
352354

353355
```
354-
runtime-array-type ::= `!spv.rtarray<` element-type `>`
356+
runtime-array-type ::= `!spv.rtarray` `<` element-type (`,` `stride` `=` integer-literal)? `>`
355357
```
356358

357359
For example,
358360

359361
```mlir
360362
!spv.rtarray<i32>
363+
!spv.rtarray<i32, stride=4>
361364
!spv.rtarray<vector<4 x f32>>
362365
```
363366

mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
namespace mlir {
1919
class Type;
2020
class VectorType;
21+
2122
namespace spirv {
22-
class StructType;
2323
class ArrayType;
24+
class RuntimeArrayType;
25+
class StructType;
2426
} // namespace spirv
2527

2628
/// According to the Vulkan spec "14.5.4. Offset and Stride Assignment":
@@ -47,21 +49,27 @@ class VulkanLayoutUtils {
4749
public:
4850
using Size = uint64_t;
4951

50-
/// Returns a new StructType with layout info. Assigns the type size in bytes
51-
/// to the `size`. Assigns the type alignment in bytes to the `alignment`.
52-
static spirv::StructType decorateType(spirv::StructType structType,
53-
Size &size, Size &alignment);
52+
/// Returns a new StructType with layout decoration.
53+
static spirv::StructType decorateType(spirv::StructType structType);
54+
5455
/// Checks whether a type is legal in terms of Vulkan layout info
5556
/// decoration. A type is dynamically illegal if it's a composite type in the
5657
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage
5758
/// Classes without layout information.
5859
static bool isLegalType(Type type);
5960

6061
private:
62+
/// Returns a new type with layout decoration. Assigns the type size in bytes
63+
/// to the `size`. Assigns the type alignment in bytes to the `alignment`.
6164
static Type decorateType(Type type, Size &size, Size &alignment);
65+
6266
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
6367
static Type decorateType(spirv::ArrayType arrayType, Size &size,
6468
Size &alignment);
69+
static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
70+
static spirv::StructType decorateType(spirv::StructType structType,
71+
Size &size, Size &alignment);
72+
6573
/// Calculates the alignment for the given scalar type.
6674
static Size getScalarTypeAlignment(Type scalarType);
6775
};

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,23 +147,22 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
147147
detail::ArrayTypeStorage> {
148148
public:
149149
using Base::Base;
150-
// Zero layout specifies that is no layout
151-
using LayoutInfo = uint64_t;
152150

153151
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
154152

155153
static ArrayType get(Type elementType, unsigned elementCount);
156154

155+
/// Returns an array type with the given stride in bytes.
157156
static ArrayType get(Type elementType, unsigned elementCount,
158-
LayoutInfo layoutInfo);
157+
unsigned stride);
159158

160159
unsigned getNumElements() const;
161160

162161
Type getElementType() const;
163162

164-
bool hasLayout() const;
165-
166-
uint64_t getArrayStride() const;
163+
/// Returns the array stride in bytes. 0 means no stride decorated on this
164+
/// type.
165+
unsigned getArrayStride() const;
167166

168167
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
169168
Optional<spirv::StorageClass> storage = llvm::None);
@@ -243,8 +242,15 @@ class RuntimeArrayType
243242

244243
static RuntimeArrayType get(Type elementType);
245244

245+
/// Returns a runtime array type with the given stride in bytes.
246+
static RuntimeArrayType get(Type elementType, unsigned stride);
247+
246248
Type getElementType() const;
247249

250+
/// Returns the array stride in bytes. 0 means no stride decorated on this
251+
/// type.
252+
unsigned getArrayStride() const;
253+
248254
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
249255
Optional<spirv::StorageClass> storage = llvm::None);
250256
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,

mlir/lib/Dialect/SPIRV/LayoutUtils.cpp

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
using namespace mlir;
1818

19+
spirv::StructType
20+
VulkanLayoutUtils::decorateType(spirv::StructType structType) {
21+
Size size = 0;
22+
Size alignment = 1;
23+
return decorateType(structType, size, alignment);
24+
}
25+
1926
spirv::StructType
2027
VulkanLayoutUtils::decorateType(spirv::StructType structType,
2128
VulkanLayoutUtils::Size &size,
@@ -25,21 +32,26 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
2532
}
2633

2734
SmallVector<Type, 4> memberTypes;
28-
SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo;
35+
SmallVector<Size, 4> layoutInfo;
2936
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
3037

31-
VulkanLayoutUtils::Size structMemberOffset = 0;
32-
VulkanLayoutUtils::Size maxMemberAlignment = 1;
38+
Size structMemberOffset = 0;
39+
Size maxMemberAlignment = 1;
3340

3441
for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
35-
VulkanLayoutUtils::Size memberSize = 0;
36-
VulkanLayoutUtils::Size memberAlignment = 1;
42+
Size memberSize = 0;
43+
Size memberAlignment = 1;
3744

38-
auto memberType = VulkanLayoutUtils::decorateType(
39-
structType.getElementType(i), memberSize, memberAlignment);
45+
auto memberType =
46+
decorateType(structType.getElementType(i), memberSize, memberAlignment);
4047
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
4148
memberTypes.push_back(memberType);
4249
layoutInfo.push_back(structMemberOffset);
50+
// If the member's size is the max value, it must be the last member and it
51+
// must be a runtime array.
52+
assert(memberSize != std::numeric_limits<Size>().max() ||
53+
(i + 1 == e &&
54+
structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
4355
// According to the Vulkan spec:
4456
// "A structure has a base alignment equal to the largest base alignment of
4557
// any of its members."
@@ -60,22 +72,22 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
6072
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
6173
VulkanLayoutUtils::Size &alignment) {
6274
if (type.isa<spirv::ScalarType>()) {
63-
alignment = VulkanLayoutUtils::getScalarTypeAlignment(type);
75+
alignment = getScalarTypeAlignment(type);
6476
// Vulkan spec does not specify any padding for a scalar type.
6577
size = alignment;
6678
return type;
6779
}
6880

6981
switch (type.getKind()) {
7082
case spirv::TypeKind::Struct:
71-
return VulkanLayoutUtils::decorateType(type.cast<spirv::StructType>(), size,
72-
alignment);
83+
return decorateType(type.cast<spirv::StructType>(), size, alignment);
7384
case spirv::TypeKind::Array:
74-
return VulkanLayoutUtils::decorateType(type.cast<spirv::ArrayType>(), size,
75-
alignment);
85+
return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
7686
case StandardTypes::Vector:
77-
return VulkanLayoutUtils::decorateType(type.cast<VectorType>(), size,
78-
alignment);
87+
return decorateType(type.cast<VectorType>(), size, alignment);
88+
case spirv::TypeKind::RuntimeArray:
89+
size = std::numeric_limits<Size>().max();
90+
return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
7991
default:
8092
llvm_unreachable("unhandled SPIR-V type");
8193
}
@@ -86,11 +98,10 @@ Type VulkanLayoutUtils::decorateType(VectorType vectorType,
8698
VulkanLayoutUtils::Size &alignment) {
8799
const auto numElements = vectorType.getNumElements();
88100
auto elementType = vectorType.getElementType();
89-
VulkanLayoutUtils::Size elementSize = 0;
90-
VulkanLayoutUtils::Size elementAlignment = 1;
101+
Size elementSize = 0;
102+
Size elementAlignment = 1;
91103

92-
auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize,
93-
elementAlignment);
104+
auto memberType = decorateType(elementType, elementSize, elementAlignment);
94105
// According to the Vulkan spec:
95106
// 1. "A two-component vector has a base alignment equal to twice its scalar
96107
// alignment."
@@ -106,11 +117,10 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
106117
VulkanLayoutUtils::Size &alignment) {
107118
const auto numElements = arrayType.getNumElements();
108119
auto elementType = arrayType.getElementType();
109-
spirv::ArrayType::LayoutInfo elementSize = 0;
110-
VulkanLayoutUtils::Size elementAlignment = 1;
120+
Size elementSize = 0;
121+
Size elementAlignment = 1;
111122

112-
auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize,
113-
elementAlignment);
123+
auto memberType = decorateType(elementType, elementSize, elementAlignment);
114124
// According to the Vulkan spec:
115125
// "An array has a base alignment equal to the base alignment of its element
116126
// type."
@@ -119,6 +129,15 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
119129
return spirv::ArrayType::get(memberType, numElements, elementSize);
120130
}
121131

132+
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
133+
VulkanLayoutUtils::Size &alignment) {
134+
auto elementType = arrayType.getElementType();
135+
Size elementSize = 0;
136+
137+
auto memberType = decorateType(elementType, elementSize, alignment);
138+
return spirv::RuntimeArrayType::get(memberType, elementSize);
139+
}
140+
122141
VulkanLayoutUtils::Size
123142
VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
124143
// According to the Vulkan spec:

mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149149
DialectAsmParser &parser);
150150

151151
template <>
152-
Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
152+
Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153153
DialectAsmParser &parser);
154154

155155
static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -196,13 +196,39 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
196196
return type;
197197
}
198198

199+
/// Parses an optional `, stride = N` assembly segment. If no parsing failure
200+
/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
201+
/// missing.
202+
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
203+
DialectAsmParser &parser,
204+
unsigned &stride) {
205+
if (failed(parser.parseOptionalComma())) {
206+
stride = 0;
207+
return success();
208+
}
209+
210+
if (parser.parseKeyword("stride") || parser.parseEqual())
211+
return failure();
212+
213+
llvm::SMLoc strideLoc = parser.getCurrentLocation();
214+
Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
215+
if (!optStride)
216+
return failure();
217+
218+
if (!(stride = optStride.getValue())) {
219+
parser.emitError(strideLoc, "ArrayStride must be greater than zero");
220+
return failure();
221+
}
222+
return success();
223+
}
224+
199225
// element-type ::= integer-type
200226
// | floating-point-type
201227
// | vector-type
202228
// | spirv-type
203229
//
204-
// array-type ::= `!spv.array<` integer-literal `x` element-type
205-
// (`[` integer-literal `]`)? `>`
230+
// array-type ::= `!spv.array` `<` integer-literal `x` element-type
231+
// (`,` `stride` `=` integer-literal)? `>`
206232
static Type parseArrayType(SPIRVDialect const &dialect,
207233
DialectAsmParser &parser) {
208234
if (parser.parseLess())
@@ -230,25 +256,13 @@ static Type parseArrayType(SPIRVDialect const &dialect,
230256
if (!elementType)
231257
return Type();
232258

233-
ArrayType::LayoutInfo layoutInfo = 0;
234-
if (succeeded(parser.parseOptionalLSquare())) {
235-
llvm::SMLoc layoutLoc = parser.getCurrentLocation();
236-
auto layout = parseAndVerify<ArrayType::LayoutInfo>(dialect, parser);
237-
if (!layout)
238-
return Type();
239-
240-
if (!(layoutInfo = layout.getValue())) {
241-
parser.emitError(layoutLoc, "ArrayStride must be greater than zero");
242-
return Type();
243-
}
244-
245-
if (parser.parseRSquare())
246-
return Type();
247-
}
259+
unsigned stride = 0;
260+
if (failed(parseOptionalArrayStride(dialect, parser, stride)))
261+
return Type();
248262

249263
if (parser.parseGreater())
250264
return Type();
251-
return ArrayType::get(elementType, count, layoutInfo);
265+
return ArrayType::get(elementType, count, stride);
252266
}
253267

254268
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
@@ -285,7 +299,8 @@ static Type parsePointerType(SPIRVDialect const &dialect,
285299
return PointerType::get(pointeeType, *storageClass);
286300
}
287301

288-
// runtime-array-type ::= `!spv.rtarray<` element-type `>`
302+
// runtime-array-type ::= `!spv.rtarray` `<` element-type
303+
// (`,` `stride` `=` integer-literal)? `>`
289304
static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
290305
DialectAsmParser &parser) {
291306
if (parser.parseLess())
@@ -295,9 +310,13 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
295310
if (!elementType)
296311
return Type();
297312

313+
unsigned stride = 0;
314+
if (failed(parseOptionalArrayStride(dialect, parser, stride)))
315+
return Type();
316+
298317
if (parser.parseGreater())
299318
return Type();
300-
return RuntimeArrayType::get(elementType);
319+
return RuntimeArrayType::get(elementType, stride);
301320
}
302321

303322
// Specialize this function to parse each of the parameters that define an
@@ -337,9 +356,9 @@ static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
337356
}
338357

339358
template <>
340-
Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
359+
Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
341360
DialectAsmParser &parser) {
342-
return parseAndVerifyInteger<uint64_t>(dialect, parser);
361+
return parseAndVerifyInteger<unsigned>(dialect, parser);
343362
}
344363

345364
namespace {
@@ -526,14 +545,16 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
526545

527546
static void print(ArrayType type, DialectAsmPrinter &os) {
528547
os << "array<" << type.getNumElements() << " x " << type.getElementType();
529-
if (type.hasLayout()) {
530-
os << " [" << type.getArrayStride() << "]";
531-
}
548+
if (unsigned stride = type.getArrayStride())
549+
os << ", stride=" << stride;
532550
os << ">";
533551
}
534552

535553
static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
536-
os << "rtarray<" << type.getElementType() << ">";
554+
os << "rtarray<" << type.getElementType();
555+
if (unsigned stride = type.getArrayStride())
556+
os << ", stride=" << stride;
557+
os << ">";
537558
}
538559

539560
static void print(PointerType type, DialectAsmPrinter &os) {

0 commit comments

Comments
 (0)