Skip to content

Commit 4c28e66

Browse files
authored
[ADT] Support appending multiple values (#69891)
This is so that we can append multiple values at once without having to create a temporary array or repeatedly call `push_back`. Use the new function `append_values` to clean up the SPIR-V serializer code. (NFC)
1 parent 211dc4a commit 4c28e66

File tree

3 files changed

+64
-33
lines changed

3 files changed

+64
-33
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,14 +2029,22 @@ void erase_value(Container &C, ValueType V) {
20292029
C.erase(std::remove(C.begin(), C.end(), V), C.end());
20302030
}
20312031

2032-
/// Wrapper function to append a range to a container.
2032+
/// Wrapper function to append range `R` to container `C`.
20332033
///
20342034
/// C.insert(C.end(), R.begin(), R.end());
20352035
template <typename Container, typename Range>
2036-
inline void append_range(Container &C, Range &&R) {
2036+
void append_range(Container &C, Range &&R) {
20372037
C.insert(C.end(), adl_begin(R), adl_end(R));
20382038
}
20392039

2040+
/// Appends all `Values` to container `C`.
2041+
template <typename Container, typename... Args>
2042+
void append_values(Container &C, Args &&...Values) {
2043+
C.reserve(range_size(C) + sizeof...(Args));
2044+
// Append all values one by one.
2045+
((void)C.insert(C.end(), std::forward<Args>(Values)), ...);
2046+
}
2047+
20402048
/// Given a sequence container Cont, replace the range [ContIt, ContEnd) with
20412049
/// the range [ValIt, ValEnd) (which is not from the same container).
20422050
template<typename Container, typename RandomAccessIterator>

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
#include <list>
1919
#include <tuple>
2020
#include <type_traits>
21+
#include <unordered_set>
2122
#include <utility>
2223
#include <vector>
2324

2425
using namespace llvm;
2526

2627
using testing::ElementsAre;
28+
using testing::UnorderedElementsAre;
2729

2830
namespace {
2931

@@ -541,6 +543,30 @@ TEST(STLExtrasTest, AppendRange) {
541543
EXPECT_THAT(Str, ElementsAre('a', 'b', 'c', '\0', 'd', 'e', 'f', '\0'));
542544
}
543545

546+
TEST(STLExtrasTest, AppendValues) {
547+
std::vector<int> Vals = {1, 2};
548+
append_values(Vals, 3);
549+
EXPECT_THAT(Vals, ElementsAre(1, 2, 3));
550+
551+
append_values(Vals, 4, 5);
552+
EXPECT_THAT(Vals, ElementsAre(1, 2, 3, 4, 5));
553+
554+
std::vector<StringRef> Strs;
555+
std::string A = "A";
556+
std::string B = "B";
557+
std::string C = "C";
558+
append_values(Strs, A, B);
559+
EXPECT_THAT(Strs, ElementsAre(A, B));
560+
append_values(Strs, C);
561+
EXPECT_THAT(Strs, ElementsAre(A, B, C));
562+
563+
std::unordered_set<int> Set;
564+
append_values(Set, 1, 2);
565+
EXPECT_THAT(Set, UnorderedElementsAre(1, 2));
566+
append_values(Set, 3, 1);
567+
EXPECT_THAT(Set, UnorderedElementsAre(1, 2, 3));
568+
}
569+
544570
TEST(STLExtrasTest, ADLTest) {
545571
some_namespace::some_struct s{{1, 2, 3, 4, 5}, ""};
546572
some_namespace::some_struct s2{{2, 4, 6, 8, 10}, ""};

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919
#include "mlir/Support/LogicalResult.h"
2020
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
21+
#include "llvm/ADT/STLExtras.h"
2122
#include "llvm/ADT/Sequence.h"
2223
#include "llvm/ADT/SmallPtrSet.h"
2324
#include "llvm/ADT/StringExtras.h"
@@ -443,13 +444,13 @@ LogicalResult Serializer::prepareBasicType(
443444
if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
444445
return failure();
445446

446-
operands.push_back(sampledTypeID);
447-
operands.push_back(static_cast<uint32_t>(imageType.getDim()));
448-
operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
449-
operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
450-
operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
451-
operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
452-
operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
447+
llvm::append_values(operands, sampledTypeID,
448+
static_cast<uint32_t>(imageType.getDim()),
449+
static_cast<uint32_t>(imageType.getDepthInfo()),
450+
static_cast<uint32_t>(imageType.getArrayedInfo()),
451+
static_cast<uint32_t>(imageType.getSamplingInfo()),
452+
static_cast<uint32_t>(imageType.getSamplerUseInfo()),
453+
static_cast<uint32_t>(imageType.getImageFormat()));
453454
return success();
454455
}
455456

@@ -605,12 +606,11 @@ LogicalResult Serializer::prepareBasicType(
605606
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
606607
return prepareConstantInt(loc, attr);
607608
};
608-
operands.push_back(elementTypeID);
609-
operands.push_back(
610-
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
611-
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
612-
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
613-
operands.push_back(
609+
llvm::append_values(
610+
operands, elementTypeID,
611+
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
612+
getConstantOp(cooperativeMatrixType.getRows()),
613+
getConstantOp(cooperativeMatrixType.getColumns()),
614614
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
615615
return success();
616616
}
@@ -627,11 +627,11 @@ LogicalResult Serializer::prepareBasicType(
627627
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
628628
return prepareConstantInt(loc, attr);
629629
};
630-
operands.push_back(elementTypeID);
631-
operands.push_back(
632-
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
633-
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
634-
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
630+
llvm::append_values(
631+
operands, elementTypeID,
632+
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
633+
getConstantOp(cooperativeMatrixType.getRows()),
634+
getConstantOp(cooperativeMatrixType.getColumns()));
635635
return success();
636636
}
637637

@@ -646,12 +646,10 @@ LogicalResult Serializer::prepareBasicType(
646646
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
647647
return prepareConstantInt(loc, attr);
648648
};
649-
operands.push_back(elementTypeID);
650-
operands.push_back(getConstantOp(jointMatrixType.getRows()));
651-
operands.push_back(getConstantOp(jointMatrixType.getColumns()));
652-
operands.push_back(getConstantOp(
653-
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
654-
operands.push_back(
649+
llvm::append_values(
650+
operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
651+
getConstantOp(jointMatrixType.getColumns()),
652+
getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
655653
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
656654
return success();
657655
}
@@ -663,8 +661,7 @@ LogicalResult Serializer::prepareBasicType(
663661
return failure();
664662
}
665663
typeEnum = spirv::Opcode::OpTypeMatrix;
666-
operands.push_back(elementTypeID);
667-
operands.push_back(matrixType.getNumColumns());
664+
llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
668665
return success();
669666
}
670667

@@ -1261,11 +1258,11 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
12611258
spirv::Decoration decoration,
12621259
ArrayRef<uint32_t> params) {
12631260
uint32_t wordCount = 3 + params.size();
1264-
decorations.push_back(
1265-
spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
1266-
decorations.push_back(target);
1267-
decorations.push_back(static_cast<uint32_t>(decoration));
1268-
decorations.append(params.begin(), params.end());
1261+
llvm::append_values(
1262+
decorations,
1263+
spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1264+
static_cast<uint32_t>(decoration));
1265+
llvm::append_range(decorations, params);
12691266
return success();
12701267
}
12711268

0 commit comments

Comments
 (0)