Skip to content

Commit 07363b9

Browse files
author
git apple-llvm automerger
committed
Merge commit '6b256e3df2f0' from apple/master into swift/master-next
2 parents 00ed11a + 6b256e3 commit 07363b9

File tree

9 files changed

+173
-105
lines changed

9 files changed

+173
-105
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -276,22 +276,40 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
276276
public:
277277
using Base::Base;
278278

279-
// Layout information used for members in a struct in SPIR-V
280-
//
281-
// TODO(ravishankarm) : For now this only supports the offset type, so uses
282-
// uint64_t value to represent the offset, with
283-
// std::numeric_limit<uint64_t>::max indicating no offset. Change this to
284-
// something that can hold all the information needed for different member
285-
// types
286-
using LayoutInfo = uint64_t;
279+
// Type for specifying the offset of the struct members
280+
using OffsetInfo = uint32_t;
281+
282+
// Type for specifying the decoration(s) on struct members
283+
struct MemberDecorationInfo {
284+
uint32_t memberIndex : 31;
285+
uint32_t hasValue : 1;
286+
Decoration decoration;
287+
uint32_t decorationValue;
288+
289+
MemberDecorationInfo(uint32_t index, uint32_t hasValue,
290+
Decoration decoration, uint32_t decorationValue)
291+
: memberIndex(index), hasValue(hasValue), decoration(decoration),
292+
decorationValue(decorationValue) {}
293+
294+
bool operator==(const MemberDecorationInfo &other) const {
295+
return (this->memberIndex == other.memberIndex) &&
296+
(this->decoration == other.decoration) &&
297+
(this->decorationValue == other.decorationValue);
298+
}
287299

288-
using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
300+
bool operator<(const MemberDecorationInfo &other) const {
301+
return this->memberIndex < other.memberIndex ||
302+
(this->memberIndex == other.memberIndex &&
303+
static_cast<uint32_t>(this->decoration) <
304+
static_cast<uint32_t>(other.decoration));
305+
}
306+
};
289307

290308
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
291309

292310
/// Construct a StructType with at least one member.
293311
static StructType get(ArrayRef<Type> memberTypes,
294-
ArrayRef<LayoutInfo> layoutInfo = {},
312+
ArrayRef<OffsetInfo> offsetInfo = {},
295313
ArrayRef<MemberDecorationInfo> memberDecorations = {});
296314

297315
/// Construct a struct with no members.
@@ -323,9 +341,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
323341

324342
ElementTypeRange getElementTypes() const;
325343

326-
bool hasLayout() const;
344+
bool hasOffset() const;
327345

328-
uint64_t getOffset(unsigned) const;
346+
uint64_t getMemberOffset(unsigned) const;
329347

330348
// Returns in `allMemberDecorations` the spirv::Decorations (apart from
331349
// Offset) associated with all members of the StructType.
@@ -334,15 +352,19 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
334352

335353
// Returns in `memberDecorations` all the spirv::Decorations (apart from
336354
// Offset) associated with the `i`-th member of the StructType.
337-
void getMemberDecorations(
338-
unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
355+
void getMemberDecorations(unsigned i,
356+
SmallVectorImpl<StructType::MemberDecorationInfo>
357+
&memberDecorations) const;
339358

340359
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
341360
Optional<spirv::StorageClass> storage = llvm::None);
342361
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
343362
Optional<spirv::StorageClass> storage = llvm::None);
344363
};
345364

365+
llvm::hash_code
366+
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
367+
346368
// SPIR-V cooperative matrix type
347369
class CooperativeMatrixNVType
348370
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,

mlir/lib/Dialect/SPIRV/LayoutUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
3232
}
3333

3434
SmallVector<Type, 4> memberTypes;
35-
SmallVector<Size, 4> layoutInfo;
35+
SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
3636
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
3737

3838
Size structMemberOffset = 0;
@@ -46,7 +46,8 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
4646
decorateType(structType.getElementType(i), memberSize, memberAlignment);
4747
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
4848
memberTypes.push_back(memberType);
49-
layoutInfo.push_back(structMemberOffset);
49+
offsetInfo.push_back(
50+
static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
5051
// If the member's size is the max value, it must be the last member and it
5152
// must be a runtime array.
5253
assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -66,7 +67,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
6667
size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
6768
alignment = maxMemberAlignment;
6869
structType.getMemberDecorations(memberDecorations);
69-
return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
70+
return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
7071
}
7172

7273
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -168,7 +169,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
168169
case spirv::StorageClass::StorageBuffer:
169170
case spirv::StorageClass::PushConstant:
170171
case spirv::StorageClass::PhysicalStorageBuffer:
171-
return structType.hasLayout() || !structType.getNumElements();
172+
return structType.hasOffset() || !structType.getNumElements();
172173
default:
173174
return true;
174175
}

mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -535,30 +535,31 @@ static Type parseImageType(SPIRVDialect const &dialect,
535535
static ParseResult parseStructMemberDecorations(
536536
SPIRVDialect const &dialect, DialectAsmParser &parser,
537537
ArrayRef<Type> memberTypes,
538-
SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
538+
SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
539539
SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
540540

541541
// Check if the first element is offset.
542-
llvm::SMLoc layoutLoc = parser.getCurrentLocation();
543-
StructType::LayoutInfo layout = 0;
544-
OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
545-
if (layoutParseResult.hasValue()) {
546-
if (failed(*layoutParseResult))
542+
llvm::SMLoc offsetLoc = parser.getCurrentLocation();
543+
StructType::OffsetInfo offset = 0;
544+
OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
545+
if (offsetParseResult.hasValue()) {
546+
if (failed(*offsetParseResult))
547547
return failure();
548548

549-
if (layoutInfo.size() != memberTypes.size() - 1) {
550-
return parser.emitError(
551-
layoutLoc, "layout specification must be given for all members");
549+
if (offsetInfo.size() != memberTypes.size() - 1) {
550+
return parser.emitError(offsetLoc,
551+
"offset specification must be given for "
552+
"all members");
552553
}
553-
layoutInfo.push_back(layout);
554+
offsetInfo.push_back(offset);
554555
}
555556

556557
// Check for no spirv::Decorations.
557558
if (succeeded(parser.parseOptionalRSquare()))
558559
return success();
559560

560-
// If there was a layout, make sure to parse the comma.
561-
if (layoutParseResult.hasValue() && parser.parseComma())
561+
// If there was an offset, make sure to parse the comma.
562+
if (offsetParseResult.hasValue() && parser.parseComma())
562563
return failure();
563564

564565
// Check for spirv::Decorations.
@@ -567,9 +568,23 @@ static ParseResult parseStructMemberDecorations(
567568
if (!memberDecoration)
568569
return failure();
569570

570-
memberDecorationInfo.emplace_back(
571-
static_cast<uint32_t>(memberTypes.size() - 1),
572-
memberDecoration.getValue());
571+
// Parse member decoration value if it exists.
572+
if (succeeded(parser.parseOptionalEqual())) {
573+
auto memberDecorationValue =
574+
parseAndVerifyInteger<uint32_t>(dialect, parser);
575+
576+
if (!memberDecorationValue)
577+
return failure();
578+
579+
memberDecorationInfo.emplace_back(
580+
static_cast<uint32_t>(memberTypes.size() - 1), 1,
581+
memberDecoration.getValue(), memberDecorationValue.getValue());
582+
} else {
583+
memberDecorationInfo.emplace_back(
584+
static_cast<uint32_t>(memberTypes.size() - 1), 0,
585+
memberDecoration.getValue(), 0);
586+
}
587+
573588
} while (succeeded(parser.parseOptionalComma()));
574589

575590
return parser.parseRSquare();
@@ -587,7 +602,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
587602
return StructType::getEmpty(dialect.getContext());
588603

589604
SmallVector<Type, 4> memberTypes;
590-
SmallVector<StructType::LayoutInfo, 4> layoutInfo;
605+
SmallVector<StructType::OffsetInfo, 4> offsetInfo;
591606
SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
592607

593608
do {
@@ -597,21 +612,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
597612
memberTypes.push_back(memberType);
598613

599614
if (succeeded(parser.parseOptionalLSquare())) {
600-
if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
615+
if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
601616
memberDecorationInfo)) {
602617
return Type();
603618
}
604619
}
605620
} while (succeeded(parser.parseOptionalComma()));
606621

607-
if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
622+
if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
608623
parser.emitError(parser.getNameLoc(),
609-
"layout specification must be given for all members");
624+
"offset specification must be given for all members");
610625
return Type();
611626
}
612627
if (parser.parseGreater())
613628
return Type();
614-
return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
629+
return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
615630
}
616631

617632
// spirv-type ::= array-type
@@ -679,17 +694,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
679694
os << "struct<";
680695
auto printMember = [&](unsigned i) {
681696
os << type.getElementType(i);
682-
SmallVector<spirv::Decoration, 0> decorations;
697+
SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
683698
type.getMemberDecorations(i, decorations);
684-
if (type.hasLayout() || !decorations.empty()) {
699+
if (type.hasOffset() || !decorations.empty()) {
685700
os << " [";
686-
if (type.hasLayout()) {
687-
os << type.getOffset(i);
701+
if (type.hasOffset()) {
702+
os << type.getMemberOffset(i);
688703
if (!decorations.empty())
689704
os << ", ";
690705
}
691-
auto eachFn = [&os](spirv::Decoration decoration) {
692-
os << stringifyDecoration(decoration);
706+
auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
707+
os << stringifyDecoration(decoration.decoration);
708+
if (decoration.hasValue) {
709+
os << "=" << decoration.decorationValue;
710+
}
693711
};
694712
llvm::interleaveComma(decorations, os, eachFn);
695713
os << "]";

0 commit comments

Comments
 (0)