[mlir] Use getSingleElement/hasSingleElement in various places #131460


Merged
matthias-springer merged 2 commits into main from users/matthias-springer/get_single_element on Mar 17, 2025

Conversation

@matthias-springer matthias-springer commented Mar 15, 2025

This is a code cleanup: update a few places in MLIR to use hasSingleElement/getSingleElement.

Note: hasSingleElement is faster than .size() == 1 for containers such as linked lists, where size() has to walk the whole list.
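
For intuition, here is a minimal standalone sketch (not the LLVM implementation; the container and names are chosen only for illustration) of why a begin/next/end check wins for containers that lack a constant-time size(), such as the intrusive block lists touched below:

```cpp
// Minimal sketch: the single-element check inspects at most two nodes,
// while counting elements of a list without a cached size walks all of it.
#include <cassert>
#include <forward_list>
#include <iterator>

template <typename ContainerTy>
bool hasSingleElementSketch(ContainerTy &&c) {
  auto b = std::begin(c), e = std::end(c);
  return b != e && std::next(b) == e; // O(1): at most two node hops
}

int main() {
  std::forward_list<int> blocks = {42};
  // Counting is O(n) here: std::forward_list has no size() at all.
  assert(std::distance(blocks.begin(), blocks.end()) == 1);
  // The single-element check is O(1).
  assert(hasSingleElementSketch(blocks));
  return 0;
}
```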

Depends on #131508.

llvmbot commented Mar 15, 2025

@llvm/pr-subscribers-mlir-quant
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-flang-openmp

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new helper function: getSingleElement

This function asserts that the container has a single element and then returns that element. The helper is useful during 1:N dialect conversions, where certain ValueRanges (returned from the adaptor) are known to contain exactly one value.
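
The snippet below re-creates the helper outside of LLVM purely for illustration (the Sketch suffix marks invented names; the authoritative version is the one added to STLExtras.h in the patch). It shows why the helper uses decltype(auto): the call yields a reference whenever the container does, which is what lets call sites such as getSingleElement(copyOp.getInputsMutable()) in this patch hand back an OpOperand& rather than a copy.

```cpp
// Standalone re-creation for illustration only.
#include <cassert>
#include <iterator>
#include <vector>

template <typename ContainerTy>
decltype(auto) getSingleElementSketch(ContainerTy &&c) {
  assert(std::begin(c) != std::end(c) &&
         std::next(std::begin(c)) == std::end(c) &&
         "expected container with single element");
  return *std::begin(c); // decltype(auto) preserves the reference
}

int main() {
  std::vector<int> v = {7};
  getSingleElementSketch(v) = 8; // returns int&, so the element is mutable
  assert(v[0] == 8);
  return 0;
}
```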

Also update a few places that should use hasSingleElement instead of .size() == 1.


Patch is 26.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131460.diff

23 Files Affected:

  • (modified) llvm/include/llvm/ADT/STLExtras.h (+8)
  • (modified) mlir/include/mlir/Dialect/CommonFolders.h (+2-4)
  • (modified) mlir/lib/Analysis/SliceAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+1-2)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+2-4)
  • (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+13-16)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp (+5-8)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+5-10)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-9)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+5-11)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+2-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+4-9)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+1-2)
  • (modified) mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp (+1-2)
  • (modified) mlir/test/lib/Analysis/TestCFGLoopInfo.cpp (+1-1)
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 78b7e94c2b3a1..dc0443c9244be 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
   return B != E && std::next(B) == E;
 }
 
+/// Asserts that the given container has a single element and returns that
+/// element.
+template <typename ContainerTy>
+decltype(auto) getSingleElement(ContainerTy &&C) {
+  assert(hasSingleElement(C) && "expected container with single element");
+  return *adl_begin(C);
+}
+
 /// Return a range covering \p RangeOrContainer with the first N elements
 /// excluded.
 template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6f497a259262a..b5a12426aff80 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
               function_ref<std::optional<ElementValueT>(ElementValueT)>>
 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
                                       CalculationT &&calculate) {
-  assert(operands.size() == 1 && "unary op takes one operands");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
@@ -268,8 +267,7 @@ template <
     class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
                           CalculationT &&calculate) {
-  assert(operands.size() == 1 && "Cast op takes one operand");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8803ba994b2c1..e01cb3a080b5c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
       // into us. For now, just bail.
       if (parentOp && backwardSlice->count(parentOp) == 0) {
         assert(parentOp->getNumRegions() == 1 &&
-               parentOp->getRegion(0).getBlocks().size() == 1);
+               llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
         getBackwardSliceImpl(parentOp, backwardSlice, options);
       }
     } else {
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1f2781aa82114..9c4dfa27b1447 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Type srcType = adaptor.getOperands().front().getType();
+    Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
     Type dstType = this->getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 1b0f023527891..df2da138d3b52 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Value cst = adaptor.getOperands().front();
+    Value cst = llvm::getSingleElement(adaptor.getOperands());
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
       return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
                                          "splat is not a composite construct");
     }
 
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
+    scalar = llvm::getSingleElement(cc.getConstituents());
 
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b0884d321bc8a..33391995885a4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     SmallVector<Value> dynDims, dynDevice;
     for (auto dim : adaptor.getDimsDynamic()) {
       // type conversion should be 1:1 for ints
-      assert(dim.size() == 1);
-      dynDims.emplace_back(dim[0]);
+      dynDims.emplace_back(llvm::getSingleElement(dim));
     }
     // same for device
     for (auto device : adaptor.getDeviceDynamic()) {
-      assert(device.size() == 1);
-      dynDevice.emplace_back(device[0]);
+      dynDevice.emplace_back(llvm::getSingleElement(device));
     }
 
     // To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
     typeConverter.addConversion([](Type type) { return type; });
 
     // convert mesh::ShardingType to a tuple of RankedTensorTypes
-    typeConverter.addConversion(
-        [](ShardingType type,
-           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-          auto i16 = IntegerType::get(type.getContext(), 16);
-          auto i64 = IntegerType::get(type.getContext(), 64);
-          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
-                                        ShapedType::kDynamic};
-          results.emplace_back(RankedTensorType::get(shp, i16));
-          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
-          results.emplace_back(RankedTensorType::get(shp, i64));
-          return success();
-        });
+    typeConverter.addConversion([](ShardingType type,
+                                   SmallVectorImpl<Type> &results)
+                                    -> std::optional<LogicalResult> {
+      auto i16 = IntegerType::get(type.getContext(), 16);
+      auto i64 = IntegerType::get(type.getContext(), 64);
+      std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
+      results.emplace_back(RankedTensorType::get(shp, i16));
+      results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+      results.emplace_back(RankedTensorType::get(shp, i64));
+      return success();
+    });
 
     // To 'extract' components, a UnrealizedConversionCastOp is expected
     // to define the input
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..9c5b9e82cd5e0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
   }
 
   applyOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
   }
 
   minMaxOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bcba17bb21544..4b4eb9ce37b4c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
       SmallVector<Operation *, 2> sibLoadOpInsts;
       sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
       // Currently findSiblingNodeToFuse searches for siblings with one load.
-      assert(sibLoadOpInsts.size() == 1);
-      Operation *sibLoadOpInst = sibLoadOpInsts[0];
+      Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
 
       // Gather 'dstNode' load ops to 'memref'.
       SmallVector<Operation *, 2> dstLoadOpInsts;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 71c6acba32d2e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
                                                ArrayRef<uint64_t> sizes,
                                                AffineForOp target) {
   SmallVector<AffineForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 6fcfa05468eea..55a09622644ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
                                                        linalg::CopyOp> {
   OpOperand &getSourceOperand(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getInputs().size() == 1 && "expected single input");
-    return copyOp.getInputsMutable()[0];
+    return llvm::getSingleElement(copyOp.getInputsMutable());
   }
 
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+    return equivalenceFn(candidate,
+                         llvm::getSingleElement(copyOp.getOutputs()));
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return copyOp.getOutputs()[0];
+    return llvm::getSingleElement(copyOp.getOutputs());
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return {copyOp.getOutputs()[0]};
+    return {llvm::getSingleElement(copyOp.getOutputs())};
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
 /// extending the lifetime of allocations.
 static bool lastNonTerminatorInRegion(Operation *op) {
   return op->getNextNode() == op->getBlock()->getTerminator() &&
-         op->getParentRegion()->getBlocks().size() == 1;
+         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
 }
 
 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 71b88d1be1b05..de834fed90e42 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
 
   static Value materializeConversion(OpBuilder &builder, Type type,
                                      ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
-    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+    return builder.create<quant::StorageCastOp>(loc, type,
+                                                llvm::getSingleElement(inputs));
   }
 
 public:
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e9d7dc1b847c6..ee46f9c97268b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 static bool doesNotAliasExternalValue(Value value, Region *region,
                                       ValueRange exceptions,
                                       const OneShotAnalysisState &state) {
-  assert(region->getBlocks().size() == 1 &&
+  assert(llvm::hasSingleElement(region->getBlocks()) &&
          "expected region with single block");
   bool result = true;
   state.applyOnAliases(value, [&](Value alias) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c0589044c26ec..40d2e254fb7dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 // CRTP
 // A base class that takes care of 1:N type conversion, which maps the converted
 // op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
     // We can not do clone as the number of result types after conversion
     // might be different.
     ForOp newOp = rewriter.create<ForOp>(
-        op.getLoc(), getSingleValue(adaptor.getLowerBound()),
-        getSingleValue(adaptor.getUpperBound()),
-        getSingleValue(adaptor.getStep()),
+        op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+        llvm::getSingleElement(adaptor.getUpperBound()),
+        llvm::getSingleElement(adaptor.getStep()),
         flattenValues(adaptor.getInitArgs()));
 
     // Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
                                       TypeRange dstTypes) const {
 
     IfOp newOp = rewriter.create<IfOp>(
-        op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+        op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+        true);
     newOp->setAttrs(op->getAttrs());
 
     // We do not need the empty blocks created by rewriter.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 19335255fd492..e9471c1dbd0b7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                  scf::ForOp target) {
   SmallVector<scf::ForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 66a2e45001781..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -38,7 +38,7 @@ struct AssumingOpInterface
     size_t resultNum = std::distance(op->getOpResults().begin(),
                                      llvm::find(op->getOpResults(), value));
     // TODO: Support multiple blocks.
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "expected exactly 1 block");
     auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "only 1 block supported");
     auto yieldOp = cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 9e9fea76416b9..948ba60ac0bbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -12,12 +12,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                              SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
 
     // Construct the iteration space.
     SparseIterationSpace space(loc, rewriter,
-                               getSingleValue(adaptor.getTensor()), 0,
+                               llvm::getSingleElement(adaptor.getTensor()), 0,
                                op.getLvlRange(), adaptor.getParentIter());
 
     SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value pos = adaptor.getIterator().back();
-    Value valBuf =
-        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+    Value valBuf = rewriter.create<ToValuesOp>(
+        loc, llvm::getSingleElement(adaptor.getTensor()));
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20d46f7ca00c5..6a66ad24a87b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 /// Generates a load with proper `index` typing.
 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
   idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                ...
[truncated]

llvmbot commented Mar 15, 2025

@llvm/pr-subscribers-mlir-affine



@kuhar kuhar left a comment


This seems like a nice addition.

For STLExtras changes, we typically require unit tests and land the change that introduces uses as a follow-up PR. This is to minimize the risk of reverts that cause a full project rebuild. Could you split this PR into two?

As for unit tests, we don't need anything fancy beyond making sure this works with a couple of data types and const / lvalue / rvalue parameters. We could also add a death test for the assertion.
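
A rough sketch of the kind of unit test being asked for, assuming it would sit next to the existing hasSingleElement tests in llvm/unittests/ADT/STLExtrasTest.cpp (the test names are invented; the death test keys off the assertion message from the patch):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
#include <vector>

namespace {

TEST(STLExtrasTest, GetSingleElement) {
  // Mutable lvalue container.
  std::vector<int> V = {42};
  EXPECT_EQ(llvm::getSingleElement(V), 42);
  // Const and rvalue containers.
  const std::vector<int> CV = {7};
  EXPECT_EQ(llvm::getSingleElement(CV), 7);
  EXPECT_EQ(llvm::getSingleElement(std::vector<int>{3}), 3);
}

#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG)
TEST(STLExtrasTest, GetSingleElementAsserts) {
  std::vector<int> Empty;
  EXPECT_DEATH(llvm::getSingleElement(Empty),
               "expected container with single element");
}
#endif

} // namespace
```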

kuhar commented Mar 16, 2025

Shower thought: we could also have `auto &[first, second, third] = getNElements<3>(range);`, although I'm not sure how often this comes up. (Not suggesting any changes to this PR, just an idea)
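
Taking that idea literally, one hypothetical shape for such a helper (it does not exist in LLVM; the name, signature, and semantics here are just one way to read the suggestion) could be:

```cpp
// Hypothetical getNElements<N> sketch; not part of LLVM.
#include <array>
#include <cassert>
#include <iterator>
#include <tuple>
#include <vector>

template <std::size_t N, typename RangeTy>
auto getNElements(RangeTy &&range) {
  auto it = std::begin(range);
  std::array<decltype(&*it), N> ptrs; // addresses of the first N elements
  for (auto &p : ptrs) {
    assert(it != std::end(range) && "range has fewer than N elements");
    p = &*it++;
  }
  assert(it == std::end(range) && "range has more than N elements");
  // Return a tuple of references so callers can use structured bindings.
  return std::apply([](auto *...p) { return std::tie(*p...); }, ptrs);
}

int main() {
  std::vector<int> v = {1, 2, 3};
  auto [first, second, third] = getNElements<3>(v);
  assert(first == 1 && second == 2 && third == 3);
  second = 20; // the bindings reference the underlying elements
  assert(v[1] == 20);
  return 0;
}
```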

@matthias-springer matthias-springer force-pushed the users/matthias-springer/get_single_element branch from 530de07 to 6217315 on March 16, 2025 10:25
@matthias-springer matthias-springer changed the title from "[llvm] Add getSingleElement helper and use in MLIR" to "[mlir] Use getSingleElement/hasSingleElement in various places" on Mar 16, 2025
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/get_single_el on March 16, 2025 10:26
@matthias-springer matthias-springer requested a review from kuhar on March 16, 2025 10:26

github-actions bot commented Mar 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/get_single_el branch 3 times, most recently from 65159c4 to 78eca92, on March 16, 2025 12:22
@matthias-springer matthias-springer force-pushed the users/matthias-springer/get_single_element branch from 6217315 to cacf322 on March 16, 2025 12:23

@kuhar kuhar left a comment


LGTM % formatting

Base automatically changed from users/matthias-springer/get_single_el to main March 16, 2025 20:20
@matthias-springer matthias-springer force-pushed the users/matthias-springer/get_single_element branch from cacf322 to 0ce8d2c on March 16, 2025 20:23
@matthias-springer matthias-springer merged commit 6c867e2 into main Mar 17, 2025
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/get_single_element branch March 17, 2025 06:43
@joker-eph
Collaborator

> For STLExtras changes, we typically require unit tests and land the change that introduces uses as a follow-up PR.

That could be a good practice in general, but it hasn't been followed very consistently here; for example, this one was a very similar patch I think (and you were OK with that one as is).


kuhar commented Mar 17, 2025

That's what I meant by 'typically require' -- trivial wrappers around simple STL functions would arguably benefit only a little from tests, but new code with an interesting contract definitely does.

@dwblaikie
Collaborator

(arguably, yes - I'd argue that even the simple wrappers could/should be tested, FWIW - but I understand it's a goal, not a hard-and-fast/inviolable rule - not every failure to meet that bar is equally problematic, etc)

@joker-eph
Collaborator

> yes - I'd argue that even the simple wrappers could/should be tested,

Right, the question is what the bar for "tested" is. For every other kind of utility, having enough in-tree uses counts as "tested" (historically, we don't have C++ gtests for every utility function).

@dwblaikie
Collaborator

> > yes - I'd argue that even the simple wrappers could/should be tested,
>
> Right, the question is what the bar for "tested" is. For every other kind of utility, having enough in-tree uses counts as "tested" (historically, we don't have C++ gtests for every utility function).

Sorry, I meant unit tested in this context/case. For things in ADT/STLExtras - yes, we don't historically have full coverage - but as the project has gotten larger I think it's become more relevant to test them in isolation more robustly (so that they're tested even when the only uses are in some non-LLVM subproject, and so that they're tested robustly regardless of which use cases we have or don't have in-tree at any given moment).


kuhar commented Mar 17, 2025

> > yes - I'd argue that even the simple wrappers could/should be tested,
>
> Right, the question is what the bar for "tested" is. For every other kind of utility, having enough in-tree uses counts as "tested" (historically, we don't have C++ gtests for every utility function).

Continuing with these two examples (min_element and getSingleElement), my thought process was roughly:

  1. Is the code tested elsewhere? For min_element, the underlying C++ function would be tested on the STL side. For getSingleElement, hasSingleElement comes with unit tests, but not in combination with the surrounding logic.
  2. Is it possible that the new uses in the codebase don't fully exercise the code? For min_element, probably not. For getSingleElement, probably yes: we don't expect to hit the assertion anywhere, and if the assertion never triggers, we wouldn't notice. We had this issue in the past with cast functions that stopped asserting for some time.
  3. Is it possible that a future modification to the implementation will introduce breakage that goes unnoticed? For min_element, highly unlikely (there aren't many ways to update this code AFAICT). For getSingleElement, potentially yes: say we change hasSingleElement to hasNItemsOrMore(C, 1) or some other function that appears to be similar.

That's why I think it was safe to wave min_element through but not getSingleElement.


joker-eph commented Mar 17, 2025

> We don't expect to hit the assertion anywhere, and if the assertion never triggers, we wouldn't notice.

This is consistent with assertions in the project in general, I believe: we don't test for assertions.


kuhar commented Mar 17, 2025

> > We don't expect to hit the assertion anywhere, and if the assertion never triggers, we wouldn't notice.
>
> This is consistent with assertions in the project in general, I believe: we don't test for assertions.

We do test for assertions in ADT code when the assertions are advertised as part of the interface. For example, many range/iterator utilities like enumerate, zip, early_inc_range, and seq are all tested for assertions. Outside of ADT we also test stuff like cast/any_cast/etc.
