Skip to content

Commit d7e6f11

Browse files
HazemAbdelhafezantiagainst
authored andcommitted
[mlir][spirv] Enhance structure type member decoration handling
Modify structure type in SPIR-V dialect to support: 1) Multiple decorations per structure member 2) Key-value based decorations (e.g., MatrixStride) This commit kept the Offset decoration separate from members' decorations container for easier implementation and logical clarity. As such, all references to Structure layoutinfo are now offsetinfo, and any member layout defining decoration (e.g., RowMajor for Matrix) will be add to the members' decorations container along with its value if any. Differential Revision: https://reviews.llvm.org/D81426
1 parent 8b828e9 commit d7e6f11

File tree

9 files changed

+173
-105
lines changed

9 files changed

+173
-105
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -276,22 +276,40 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
276276
public:
277277
using Base::Base;
278278

279-
// Layout information used for members in a struct in SPIR-V
280-
//
281-
// TODO(ravishankarm) : For now this only supports the offset type, so uses
282-
// uint64_t value to represent the offset, with
283-
// std::numeric_limit<uint64_t>::max indicating no offset. Change this to
284-
// something that can hold all the information needed for different member
285-
// types
286-
using LayoutInfo = uint64_t;
279+
// Type for specifying the offset of the struct members
280+
using OffsetInfo = uint32_t;
281+
282+
// Type for specifying the decoration(s) on struct members
283+
struct MemberDecorationInfo {
284+
uint32_t memberIndex : 31;
285+
uint32_t hasValue : 1;
286+
Decoration decoration;
287+
uint32_t decorationValue;
288+
289+
MemberDecorationInfo(uint32_t index, uint32_t hasValue,
290+
Decoration decoration, uint32_t decorationValue)
291+
: memberIndex(index), hasValue(hasValue), decoration(decoration),
292+
decorationValue(decorationValue) {}
293+
294+
bool operator==(const MemberDecorationInfo &other) const {
295+
return (this->memberIndex == other.memberIndex) &&
296+
(this->decoration == other.decoration) &&
297+
(this->decorationValue == other.decorationValue);
298+
}
287299

288-
using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
300+
bool operator<(const MemberDecorationInfo &other) const {
301+
return this->memberIndex < other.memberIndex ||
302+
(this->memberIndex == other.memberIndex &&
303+
static_cast<uint32_t>(this->decoration) <
304+
static_cast<uint32_t>(other.decoration));
305+
}
306+
};
289307

290308
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
291309

292310
/// Construct a StructType with at least one member.
293311
static StructType get(ArrayRef<Type> memberTypes,
294-
ArrayRef<LayoutInfo> layoutInfo = {},
312+
ArrayRef<OffsetInfo> offsetInfo = {},
295313
ArrayRef<MemberDecorationInfo> memberDecorations = {});
296314

297315
/// Construct a struct with no members.
@@ -323,9 +341,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
323341

324342
ElementTypeRange getElementTypes() const;
325343

326-
bool hasLayout() const;
344+
bool hasOffset() const;
327345

328-
uint64_t getOffset(unsigned) const;
346+
uint64_t getMemberOffset(unsigned) const;
329347

330348
// Returns in `allMemberDecorations` the spirv::Decorations (apart from
331349
// Offset) associated with all members of the StructType.
@@ -334,15 +352,19 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
334352

335353
// Returns in `memberDecorations` all the spirv::Decorations (apart from
336354
// Offset) associated with the `i`-th member of the StructType.
337-
void getMemberDecorations(
338-
unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
355+
void getMemberDecorations(unsigned i,
356+
SmallVectorImpl<StructType::MemberDecorationInfo>
357+
&memberDecorations) const;
339358

340359
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
341360
Optional<spirv::StorageClass> storage = llvm::None);
342361
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
343362
Optional<spirv::StorageClass> storage = llvm::None);
344363
};
345364

365+
llvm::hash_code
366+
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
367+
346368
// SPIR-V cooperative matrix type
347369
class CooperativeMatrixNVType
348370
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,

mlir/lib/Dialect/SPIRV/LayoutUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
3232
}
3333

3434
SmallVector<Type, 4> memberTypes;
35-
SmallVector<Size, 4> layoutInfo;
35+
SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
3636
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
3737

3838
Size structMemberOffset = 0;
@@ -46,7 +46,8 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
4646
decorateType(structType.getElementType(i), memberSize, memberAlignment);
4747
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
4848
memberTypes.push_back(memberType);
49-
layoutInfo.push_back(structMemberOffset);
49+
offsetInfo.push_back(
50+
static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
5051
// If the member's size is the max value, it must be the last member and it
5152
// must be a runtime array.
5253
assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -66,7 +67,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
6667
size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
6768
alignment = maxMemberAlignment;
6869
structType.getMemberDecorations(memberDecorations);
69-
return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
70+
return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
7071
}
7172

7273
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -168,7 +169,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
168169
case spirv::StorageClass::StorageBuffer:
169170
case spirv::StorageClass::PushConstant:
170171
case spirv::StorageClass::PhysicalStorageBuffer:
171-
return structType.hasLayout() || !structType.getNumElements();
172+
return structType.hasOffset() || !structType.getNumElements();
172173
default:
173174
return true;
174175
}

mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -535,30 +535,31 @@ static Type parseImageType(SPIRVDialect const &dialect,
535535
static ParseResult parseStructMemberDecorations(
536536
SPIRVDialect const &dialect, DialectAsmParser &parser,
537537
ArrayRef<Type> memberTypes,
538-
SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
538+
SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
539539
SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
540540

541541
// Check if the first element is offset.
542-
llvm::SMLoc layoutLoc = parser.getCurrentLocation();
543-
StructType::LayoutInfo layout = 0;
544-
OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
545-
if (layoutParseResult.hasValue()) {
546-
if (failed(*layoutParseResult))
542+
llvm::SMLoc offsetLoc = parser.getCurrentLocation();
543+
StructType::OffsetInfo offset = 0;
544+
OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
545+
if (offsetParseResult.hasValue()) {
546+
if (failed(*offsetParseResult))
547547
return failure();
548548

549-
if (layoutInfo.size() != memberTypes.size() - 1) {
550-
return parser.emitError(
551-
layoutLoc, "layout specification must be given for all members");
549+
if (offsetInfo.size() != memberTypes.size() - 1) {
550+
return parser.emitError(offsetLoc,
551+
"offset specification must be given for "
552+
"all members");
552553
}
553-
layoutInfo.push_back(layout);
554+
offsetInfo.push_back(offset);
554555
}
555556

556557
// Check for no spirv::Decorations.
557558
if (succeeded(parser.parseOptionalRSquare()))
558559
return success();
559560

560-
// If there was a layout, make sure to parse the comma.
561-
if (layoutParseResult.hasValue() && parser.parseComma())
561+
// If there was an offset, make sure to parse the comma.
562+
if (offsetParseResult.hasValue() && parser.parseComma())
562563
return failure();
563564

564565
// Check for spirv::Decorations.
@@ -567,9 +568,23 @@ static ParseResult parseStructMemberDecorations(
567568
if (!memberDecoration)
568569
return failure();
569570

570-
memberDecorationInfo.emplace_back(
571-
static_cast<uint32_t>(memberTypes.size() - 1),
572-
memberDecoration.getValue());
571+
// Parse member decoration value if it exists.
572+
if (succeeded(parser.parseOptionalEqual())) {
573+
auto memberDecorationValue =
574+
parseAndVerifyInteger<uint32_t>(dialect, parser);
575+
576+
if (!memberDecorationValue)
577+
return failure();
578+
579+
memberDecorationInfo.emplace_back(
580+
static_cast<uint32_t>(memberTypes.size() - 1), 1,
581+
memberDecoration.getValue(), memberDecorationValue.getValue());
582+
} else {
583+
memberDecorationInfo.emplace_back(
584+
static_cast<uint32_t>(memberTypes.size() - 1), 0,
585+
memberDecoration.getValue(), 0);
586+
}
587+
573588
} while (succeeded(parser.parseOptionalComma()));
574589

575590
return parser.parseRSquare();
@@ -587,7 +602,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
587602
return StructType::getEmpty(dialect.getContext());
588603

589604
SmallVector<Type, 4> memberTypes;
590-
SmallVector<StructType::LayoutInfo, 4> layoutInfo;
605+
SmallVector<StructType::OffsetInfo, 4> offsetInfo;
591606
SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
592607

593608
do {
@@ -597,21 +612,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
597612
memberTypes.push_back(memberType);
598613

599614
if (succeeded(parser.parseOptionalLSquare())) {
600-
if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
615+
if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
601616
memberDecorationInfo)) {
602617
return Type();
603618
}
604619
}
605620
} while (succeeded(parser.parseOptionalComma()));
606621

607-
if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
622+
if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
608623
parser.emitError(parser.getNameLoc(),
609-
"layout specification must be given for all members");
624+
"offset specification must be given for all members");
610625
return Type();
611626
}
612627
if (parser.parseGreater())
613628
return Type();
614-
return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
629+
return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
615630
}
616631

617632
// spirv-type ::= array-type
@@ -679,17 +694,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
679694
os << "struct<";
680695
auto printMember = [&](unsigned i) {
681696
os << type.getElementType(i);
682-
SmallVector<spirv::Decoration, 0> decorations;
697+
SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
683698
type.getMemberDecorations(i, decorations);
684-
if (type.hasLayout() || !decorations.empty()) {
699+
if (type.hasOffset() || !decorations.empty()) {
685700
os << " [";
686-
if (type.hasLayout()) {
687-
os << type.getOffset(i);
701+
if (type.hasOffset()) {
702+
os << type.getMemberOffset(i);
688703
if (!decorations.empty())
689704
os << ", ";
690705
}
691-
auto eachFn = [&os](spirv::Decoration decoration) {
692-
os << stringifyDecoration(decoration);
706+
auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
707+
os << stringifyDecoration(decoration.decoration);
708+
if (decoration.hasValue) {
709+
os << "=" << decoration.decorationValue;
710+
}
693711
};
694712
llvm::interleaveComma(decorations, os, eachFn);
695713
os << "]";

0 commit comments

Comments
 (0)