-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Use getSingleElement
/hasSingleElement
in various places
#131460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Use getSingleElement
/hasSingleElement
in various places
#131460
Conversation
@llvm/pr-subscribers-mlir-quant @llvm/pr-subscribers-flang-openmp Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new helper function: This function asserts that the container has a single element and then returns that element. This helper function is useful during 1:N dialect conversions, where certain Also update a few places that should use Patch is 26.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131460.diff 23 Files Affected:
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 78b7e94c2b3a1..dc0443c9244be 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
return B != E && std::next(B) == E;
}
+/// Asserts that the given container has a single element and returns that
+/// element.
+template <typename ContainerTy>
+decltype(auto) getSingleElement(ContainerTy &&C) {
+ assert(hasSingleElement(C) && "expected container with single element");
+ return *adl_begin(C);
+}
+
/// Return a range covering \p RangeOrContainer with the first N elements
/// excluded.
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6f497a259262a..b5a12426aff80 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "unary op takes one operands");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
@@ -268,8 +267,7 @@ template <
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "Cast op takes one operand");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8803ba994b2c1..e01cb3a080b5c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
- parentOp->getRegion(0).getBlocks().size() == 1);
+ llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1f2781aa82114..9c4dfa27b1447 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Type srcType = adaptor.getOperands().front().getType();
+ Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 1b0f023527891..df2da138d3b52 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Value cst = adaptor.getOperands().front();
+ Value cst = llvm::getSingleElement(adaptor.getOperands());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
"splat is not a composite construct");
}
- assert(cc.getConstituents().size() == 1);
- scalar = cc.getConstituents().front();
+ scalar = llvm::getSingleElement(cc.getConstituents());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b0884d321bc8a..33391995885a4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> dynDims, dynDevice;
for (auto dim : adaptor.getDimsDynamic()) {
// type conversion should be 1:1 for ints
- assert(dim.size() == 1);
- dynDims.emplace_back(dim[0]);
+ dynDims.emplace_back(llvm::getSingleElement(dim));
}
// same for device
for (auto device : adaptor.getDeviceDynamic()) {
- assert(device.size() == 1);
- dynDevice.emplace_back(device[0]);
+ dynDevice.emplace_back(llvm::getSingleElement(device));
}
// To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
typeConverter.addConversion([](Type type) { return type; });
// convert mesh::ShardingType to a tuple of RankedTensorTypes
- typeConverter.addConversion(
- [](ShardingType type,
- SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- auto i16 = IntegerType::get(type.getContext(), 16);
- auto i64 = IntegerType::get(type.getContext(), 64);
- std::array<int64_t, 2> shp = {ShapedType::kDynamic,
- ShapedType::kDynamic};
- results.emplace_back(RankedTensorType::get(shp, i16));
- results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
- results.emplace_back(RankedTensorType::get(shp, i64));
- return success();
- });
+ typeConverter.addConversion([](ShardingType type,
+ SmallVectorImpl<Type> &results)
+ -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return success();
+ });
// To 'extract' components, a UnrealizedConversionCastOp is expected
// to define the input
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..9c5b9e82cd5e0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
}
applyOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
}
minMaxOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bcba17bb21544..4b4eb9ce37b4c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
SmallVector<Operation *, 2> sibLoadOpInsts;
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
// Currently findSiblingNodeToFuse searches for siblings with one load.
- assert(sibLoadOpInsts.size() == 1);
- Operation *sibLoadOpInst = sibLoadOpInsts[0];
+ Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
// Gather 'dstNode' load ops to 'memref'.
SmallVector<Operation *, 2> dstLoadOpInsts;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 71c6acba32d2e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes,
AffineForOp target) {
SmallVector<AffineForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 6fcfa05468eea..55a09622644ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
linalg::CopyOp> {
OpOperand &getSourceOperand(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getInputs().size() == 1 && "expected single input");
- return copyOp.getInputsMutable()[0];
+ return llvm::getSingleElement(copyOp.getInputsMutable());
}
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+ return equivalenceFn(candidate,
+ llvm::getSingleElement(copyOp.getOutputs()));
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return copyOp.getOutputs()[0];
+ return llvm::getSingleElement(copyOp.getOutputs());
}
SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return {copyOp.getOutputs()[0]};
+ return {llvm::getSingleElement(copyOp.getOutputs())};
}
};
} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
return op->getNextNode() == op->getBlock()->getTerminator() &&
- op->getParentRegion()->getBlocks().size() == 1;
+ llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 71b88d1be1b05..de834fed90e42 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
static Value materializeConversion(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
- assert(inputs.size() == 1);
- return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+ return builder.create<quant::StorageCastOp>(loc, type,
+ llvm::getSingleElement(inputs));
}
public:
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e9d7dc1b847c6..ee46f9c97268b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
static bool doesNotAliasExternalValue(Value value, Region *region,
ValueRange exceptions,
const OneShotAnalysisState &state) {
- assert(region->getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(region->getBlocks()) &&
"expected region with single block");
bool result = true;
state.applyOnAliases(value, [&](Value alias) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c0589044c26ec..40d2e254fb7dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
// CRTP
// A base class that takes care of 1:N type conversion, which maps the converted
// op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(
- op.getLoc(), getSingleValue(adaptor.getLowerBound()),
- getSingleValue(adaptor.getUpperBound()),
- getSingleValue(adaptor.getStep()),
+ op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+ llvm::getSingleElement(adaptor.getUpperBound()),
+ llvm::getSingleElement(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));
// Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
TypeRange dstTypes) const {
IfOp newOp = rewriter.create<IfOp>(
- op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+ op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+ true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 19335255fd492..e9471c1dbd0b7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
scf::ForOp target) {
SmallVector<scf::ForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 66a2e45001781..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -38,7 +38,7 @@ struct AssumingOpInterface
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
// TODO: Support multiple blocks.
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"expected exactly 1 block");
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
auto yieldOp = cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 9e9fea76416b9..948ba60ac0bbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -12,12 +12,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
// Construct the iteration space.
SparseIterationSpace space(loc, rewriter,
- getSingleValue(adaptor.getTensor()), 0,
+ llvm::getSingleElement(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());
SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
- Value valBuf =
- rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+ Value valBuf = rewriter.create<ToValuesOp>(
+ loc, llvm::getSingleElement(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20d46f7ca00c5..6a66ad24a87b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
...
[truncated]
|
@llvm/pr-subscribers-mlir-affine Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new helper function: This function asserts that the container has a single element and then returns that element. This helper function is useful during 1:N dialect conversions, where certain Also update a few places that should use Patch is 26.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131460.diff 23 Files Affected:
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 78b7e94c2b3a1..dc0443c9244be 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
return B != E && std::next(B) == E;
}
+/// Asserts that the given container has a single element and returns that
+/// element.
+template <typename ContainerTy>
+decltype(auto) getSingleElement(ContainerTy &&C) {
+ assert(hasSingleElement(C) && "expected container with single element");
+ return *adl_begin(C);
+}
+
/// Return a range covering \p RangeOrContainer with the first N elements
/// excluded.
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6f497a259262a..b5a12426aff80 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "unary op takes one operands");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
@@ -268,8 +267,7 @@ template <
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "Cast op takes one operand");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8803ba994b2c1..e01cb3a080b5c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
- parentOp->getRegion(0).getBlocks().size() == 1);
+ llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1f2781aa82114..9c4dfa27b1447 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Type srcType = adaptor.getOperands().front().getType();
+ Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 1b0f023527891..df2da138d3b52 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Value cst = adaptor.getOperands().front();
+ Value cst = llvm::getSingleElement(adaptor.getOperands());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
"splat is not a composite construct");
}
- assert(cc.getConstituents().size() == 1);
- scalar = cc.getConstituents().front();
+ scalar = llvm::getSingleElement(cc.getConstituents());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b0884d321bc8a..33391995885a4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> dynDims, dynDevice;
for (auto dim : adaptor.getDimsDynamic()) {
// type conversion should be 1:1 for ints
- assert(dim.size() == 1);
- dynDims.emplace_back(dim[0]);
+ dynDims.emplace_back(llvm::getSingleElement(dim));
}
// same for device
for (auto device : adaptor.getDeviceDynamic()) {
- assert(device.size() == 1);
- dynDevice.emplace_back(device[0]);
+ dynDevice.emplace_back(llvm::getSingleElement(device));
}
// To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
typeConverter.addConversion([](Type type) { return type; });
// convert mesh::ShardingType to a tuple of RankedTensorTypes
- typeConverter.addConversion(
- [](ShardingType type,
- SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- auto i16 = IntegerType::get(type.getContext(), 16);
- auto i64 = IntegerType::get(type.getContext(), 64);
- std::array<int64_t, 2> shp = {ShapedType::kDynamic,
- ShapedType::kDynamic};
- results.emplace_back(RankedTensorType::get(shp, i16));
- results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
- results.emplace_back(RankedTensorType::get(shp, i64));
- return success();
- });
+ typeConverter.addConversion([](ShardingType type,
+ SmallVectorImpl<Type> &results)
+ -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return success();
+ });
// To 'extract' components, a UnrealizedConversionCastOp is expected
// to define the input
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..9c5b9e82cd5e0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
}
applyOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
}
minMaxOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bcba17bb21544..4b4eb9ce37b4c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
SmallVector<Operation *, 2> sibLoadOpInsts;
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
// Currently findSiblingNodeToFuse searches for siblings with one load.
- assert(sibLoadOpInsts.size() == 1);
- Operation *sibLoadOpInst = sibLoadOpInsts[0];
+ Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
// Gather 'dstNode' load ops to 'memref'.
SmallVector<Operation *, 2> dstLoadOpInsts;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 71c6acba32d2e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes,
AffineForOp target) {
SmallVector<AffineForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 6fcfa05468eea..55a09622644ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
linalg::CopyOp> {
OpOperand &getSourceOperand(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getInputs().size() == 1 && "expected single input");
- return copyOp.getInputsMutable()[0];
+ return llvm::getSingleElement(copyOp.getInputsMutable());
}
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+ return equivalenceFn(candidate,
+ llvm::getSingleElement(copyOp.getOutputs()));
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return copyOp.getOutputs()[0];
+ return llvm::getSingleElement(copyOp.getOutputs());
}
SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return {copyOp.getOutputs()[0]};
+ return {llvm::getSingleElement(copyOp.getOutputs())};
}
};
} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
return op->getNextNode() == op->getBlock()->getTerminator() &&
- op->getParentRegion()->getBlocks().size() == 1;
+ llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 71b88d1be1b05..de834fed90e42 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
static Value materializeConversion(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
- assert(inputs.size() == 1);
- return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+ return builder.create<quant::StorageCastOp>(loc, type,
+ llvm::getSingleElement(inputs));
}
public:
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e9d7dc1b847c6..ee46f9c97268b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
static bool doesNotAliasExternalValue(Value value, Region *region,
ValueRange exceptions,
const OneShotAnalysisState &state) {
- assert(region->getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(region->getBlocks()) &&
"expected region with single block");
bool result = true;
state.applyOnAliases(value, [&](Value alias) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c0589044c26ec..40d2e254fb7dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
// CRTP
// A base class that takes care of 1:N type conversion, which maps the converted
// op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(
- op.getLoc(), getSingleValue(adaptor.getLowerBound()),
- getSingleValue(adaptor.getUpperBound()),
- getSingleValue(adaptor.getStep()),
+ op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+ llvm::getSingleElement(adaptor.getUpperBound()),
+ llvm::getSingleElement(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));
// Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
TypeRange dstTypes) const {
IfOp newOp = rewriter.create<IfOp>(
- op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+ op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+ true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 19335255fd492..e9471c1dbd0b7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
scf::ForOp target) {
SmallVector<scf::ForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 66a2e45001781..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -38,7 +38,7 @@ struct AssumingOpInterface
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
// TODO: Support multiple blocks.
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"expected exactly 1 block");
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
auto yieldOp = cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 9e9fea76416b9..948ba60ac0bbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -12,12 +12,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
// Construct the iteration space.
SparseIterationSpace space(loc, rewriter,
- getSingleValue(adaptor.getTensor()), 0,
+ llvm::getSingleElement(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());
SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
- Value valBuf =
- rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+ Value valBuf = rewriter.create<ToValuesOp>(
+ loc, llvm::getSingleElement(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20d46f7ca00c5..6a66ad24a87b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a nice addition.
For STLExtras changes, we typically require unit tests and land the change that introduces uses as a follow up PR. This is to minimize the risk of reverts that cause a full project rebuild. Could you split this PR into two?
For as for unit tests, we don't need anything fancy beyond making sure this works with a couple of data types and const / lvalue / rvalue parameters. We could also add a death test for the assertion.
Shower thought: we could also have |
530de07
to
6217315
Compare
getSingleElement
helper and use in MLIRgetSingleElement
/hasSingleElement
in various places
✅ With the latest revision this PR passed the C/C++ code formatter. |
65159c4
to
78eca92
Compare
6217315
to
cacf322
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % formatting
cacf322
to
0ce8d2c
Compare
That could be good practices in general, but that hasn't been very followed here, for example this one was a very similar patch I think (and you were OK on this one as is). |
That's what I meant by 'typically require' -- trivial wrappers around simple STL functions would arguably benefit only a little from tests, but new code with interesting contract definitely does. |
(arguably, yes - I'd argue that even the simple wrappers could/should be tested, FWIW - but I understand it's a goal, not a hard-and-fast/inviolable rule - not every failure to meet that bar is equally problematic, etc) |
Right, the question is what is the bar for "tested". For every other kind of utilities, having enough in-tree uses counts as "tested" (we don't have C++ gtests for every possible utility functions historically). |
Sorry, I meant unit tested in this context/case. For things in ADT/STLExtras - yes, we don't, historically, have full coverage - but as the project's gotten larger I think it's become more relevant to test them in isolation more robustly (so that they're tested even when the only uses are in some non-LLVM subproject, so that they're tested robustly regardless of which use cases we have/don't have in-tree at any given moment, etc). |
Continuing with these two examples (min_element and getSingleElement), my thought process was roughly:
That's why I think it was safe to wave min_element through but not getSingleElement. |
This is consistent with assertions in the projects in general I believe: we don't test for assertions. |
We do test for assertions in ADT code when the assertions are advertised as part of the interface. For example many range/iterator functions like |
This is a code cleanup. Update a few places in MLIR that should use
hasSingleElement
/getSingleElement
.Note:
hasSingleElement
is faster than.getSize() == 1
when it is used with linked lists etc.Depends on #131508.