[mlir] Use getSingleElement/hasSingleElement in various places #131460

Merged 2 commits on Mar 17, 2025
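
This patch replaces hand-rolled `assert(range.size() == 1)` / `range.front()` patterns (and several local `getSingleValue` helpers) with the shared LLVM ADT utilities `llvm::getSingleElement` and `llvm::hasSingleElement` from STLExtras. As a rough orientation, the sketch below shows the behavior the two helpers provide; it is a simplified illustration, not the actual STLExtras implementation, and the `*Sketch` names are made up for this example.

#include <cassert>
#include <iterator>

// Roughly what llvm::hasSingleElement does: true iff the range holds
// exactly one element (works for ranges that lack a size() method).
template <typename RangeT>
bool hasSingleElementSketch(RangeT &&range) {
  auto it = std::begin(range);
  auto end = std::end(range);
  return it != end && std::next(it) == end;
}

// Roughly what llvm::getSingleElement does: assert that the range holds
// exactly one element and return it, collapsing the common
// `assert(r.size() == 1); use(r.front());` pattern into one call.
template <typename RangeT>
decltype(auto) getSingleElementSketch(RangeT &&range) {
  assert(hasSingleElementSketch(range) && "expected a single element");
  return *std::begin(range);
}

With these helpers, call sites such as `assert(loops.size() == 1); res.push_back(loops[0]);` become a single `res.push_back(llvm::getSingleElement(loops));`, as in the hunks below.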
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
           function_ref<std::optional<ElementValueT>(ElementValueT)>>
 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
                                       CalculationT &&calculate) {
-  assert(operands.size() == 1 && "unary op takes one operands");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
@@ -268,8 +267,7 @@ template <
     class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
                           CalculationT &&calculate) {
-  assert(operands.size() == 1 && "Cast op takes one operand");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
     // into us. For now, just bail.
     if (parentOp && backwardSlice->count(parentOp) == 0) {
       assert(parentOp->getNumRegions() == 1 &&
-             parentOp->getRegion(0).getBlocks().size() == 1);
+             llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
       getBackwardSliceImpl(parentOp, backwardSlice, options);
     }
   } else {
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Type srcType = adaptor.getOperands().front().getType();
+    Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
     Type dstType = this->getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Value cst = adaptor.getOperands().front();
+    Value cst = llvm::getSingleElement(adaptor.getOperands());
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
       return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
                                          "splat is not a composite construct");
     }
 
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
+    scalar = llvm::getSingleElement(cc.getConstituents());
 
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     SmallVector<Value> dynDims, dynDevice;
     for (auto dim : adaptor.getDimsDynamic()) {
       // type conversion should be 1:1 for ints
-      assert(dim.size() == 1);
-      dynDims.emplace_back(dim[0]);
+      dynDims.emplace_back(llvm::getSingleElement(dim));
     }
     // same for device
     for (auto device : adaptor.getDeviceDynamic()) {
-      assert(device.size() == 1);
-      dynDevice.emplace_back(device[0]);
+      dynDevice.emplace_back(llvm::getSingleElement(device));
     }
 
     // To keep the code simple, convert dims/device to values when they are
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
   }
 
   applyOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
   }
 
   minMaxOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
       SmallVector<Operation *, 2> sibLoadOpInsts;
       sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
       // Currently findSiblingNodeToFuse searches for siblings with one load.
-      assert(sibLoadOpInsts.size() == 1);
-      Operation *sibLoadOpInst = sibLoadOpInsts[0];
+      Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
 
       // Gather 'dstNode' load ops to 'memref'.
       SmallVector<Operation *, 2> dstLoadOpInsts;
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
                                                ArrayRef<uint64_t> sizes,
                                                AffineForOp target) {
   SmallVector<AffineForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
                                                     linalg::CopyOp> {
   OpOperand &getSourceOperand(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getInputs().size() == 1 && "expected single input");
-    return copyOp.getInputsMutable()[0];
+    return llvm::getSingleElement(copyOp.getInputsMutable());
   }
 
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+    return equivalenceFn(candidate,
+                         llvm::getSingleElement(copyOp.getOutputs()));
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return copyOp.getOutputs()[0];
+    return llvm::getSingleElement(copyOp.getOutputs());
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return {copyOp.getOutputs()[0]};
+    return {llvm::getSingleElement(copyOp.getOutputs())};
   }
 };
 } // namespace
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
 /// extending the lifetime of allocations.
 static bool lastNonTerminatorInRegion(Operation *op) {
   return op->getNextNode() == op->getBlock()->getTerminator() &&
-         op->getParentRegion()->getBlocks().size() == 1;
+         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
 }
 
 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
 
   static Value materializeConversion(OpBuilder &builder, Type type,
                                      ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
-    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+    return builder.create<quant::StorageCastOp>(loc, type,
+                                                llvm::getSingleElement(inputs));
   }
 
 public:
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 static bool doesNotAliasExternalValue(Value value, Region *region,
                                       ValueRange exceptions,
                                       const OneShotAnalysisState &state) {
-  assert(region->getBlocks().size() == 1 &&
+  assert(llvm::hasSingleElement(region->getBlocks()) &&
          "expected region with single block");
   bool result = true;
   state.applyOnAliases(value, [&](Value alias) {
15 changes: 5 additions & 10 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 // CRTP
 // A base class that takes care of 1:N type conversion, which maps the converted
 // op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
     // We can not do clone as the number of result types after conversion
     // might be different.
     ForOp newOp = rewriter.create<ForOp>(
-        op.getLoc(), getSingleValue(adaptor.getLowerBound()),
-        getSingleValue(adaptor.getUpperBound()),
-        getSingleValue(adaptor.getStep()),
+        op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+        llvm::getSingleElement(adaptor.getUpperBound()),
+        llvm::getSingleElement(adaptor.getStep()),
         flattenValues(adaptor.getInitArgs()));
 
     // Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
                                       TypeRange dstTypes) const {
 
     IfOp newOp = rewriter.create<IfOp>(
-        op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+        op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+        true);
     newOp->setAttrs(op->getAttrs());
 
     // We do not need the empty blocks created by rewriter.
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                  scf::ForOp target) {
   SmallVector<scf::ForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
@@ -38,7 +38,7 @@ struct AssumingOpInterface
     size_t resultNum = std::distance(op->getOpResults().begin(),
                                      llvm::find(op->getOpResults(), value));
     // TODO: Support multiple blocks.
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "expected exactly 1 block");
     auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "only 1 block supported");
     auto yieldOp = cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
@@ -12,12 +12,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                              SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
 
     // Construct the iteration space.
     SparseIterationSpace space(loc, rewriter,
-                               getSingleValue(adaptor.getTensor()), 0,
+                               llvm::getSingleElement(adaptor.getTensor()), 0,
                                op.getLvlRange(), adaptor.getParentIter());
 
     SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value pos = adaptor.getIterator().back();
-    Value valBuf =
-        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+    Value valBuf = rewriter.create<ToValuesOp>(
+        loc, llvm::getSingleElement(adaptor.getTensor()));
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
     return success();
   }
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 /// Generates a load with proper `index` typing.
 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
   idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                 op.getTensor().getType());
-    Value values = getSingleValue(adaptor.getValues());
-    Value filled = getSingleValue(adaptor.getFilled());
-    Value added = getSingleValue(adaptor.getAdded());
-    Value count = getSingleValue(adaptor.getCount());
+    Value values = llvm::getSingleElement(adaptor.getValues());
+    Value filled = llvm::getSingleElement(adaptor.getFilled());
+    Value added = llvm::getSingleElement(adaptor.getAdded());
+    Value count = llvm::getSingleElement(adaptor.getCount());
     const SparseTensorType dstType(desc.getRankedTensorType());
     Type eltType = dstType.getElementType();
 
@@ -1041,7 +1035,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
     SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
     params.append(flatIndices.begin(), flatIndices.end());
-    params.push_back(getSingleValue(adaptor.getScalar()));
+    params.push_back(llvm::getSingleElement(adaptor.getScalar()));
     SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -521,9 +521,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value ptr = genSubscript(env, builder, t, args);
   if (llvm::isa<TensorType>(ptr.getType())) {
     assert(env.options().sparseEmitStrategy ==
-               SparseEmitStrategy::kSparseIterator &&
-           args.size() == 1);
-    return builder.create<ExtractValOp>(loc, ptr, args.front());
+           SparseEmitStrategy::kSparseIterator);
+    return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
   }
   return builder.create<memref::LoadOp>(loc, ptr, args);
 }
@@ -1106,9 +1106,7 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
         Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
         return {notLegit};
       });
-
-  assert(r.size() == 1);
-  return r.front();
+  return llvm::getSingleElement(r);
 }
 
 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
@@ -1120,8 +1118,7 @@ Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
         // crd < size
        return {CMPI(ult, crd, size)};
      });
-  assert(r.size() == 1);
-  return r.front();
+  return llvm::getSingleElement(r);
 }
 
 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
@@ -1145,7 +1142,6 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
       /*beforeBuilder=*/
       [this](OpBuilder &b, Location l, ValueRange ivs) {
         ValueRange isFirst = linkNewScope(ivs);
-        assert(isFirst.size() == 1);
         scf::ValueVector cont =
             genWhenInBound(b, l, *wrap, C_FALSE,
                            [this, isFirst](OpBuilder &b, Location l,
@@ -1155,7 +1151,7 @@
                                  genCrdNotLegitPredicate(b, l, wrapCrd);
                              Value crd = fromWrapCrd(b, l, wrapCrd);
                              Value ret = ANDI(CMPI(ult, crd, size), notLegit);
-                             ret = ORI(ret, isFirst.front());
+                             ret = ORI(ret, llvm::getSingleElement(isFirst));
                              return {ret};
                            });
         b.create<scf::ConditionOp>(l, cont.front(), ivs);
@@ -1200,8 +1196,7 @@ Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
         // crd < size
         return {CMPI(ult, crd, subSect.subSectSz)};
       });
-  assert(r.size() == 1);
-  return r.front();
+  return llvm::getSingleElement(r);
 }
 
 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
@@ -833,8 +833,7 @@ makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                                        moduleTranslation, &phis)))
       return llvm::createStringError(
           "failed to inline `combiner` region of `omp.declare_reduction`");
-    assert(phis.size() == 1);
-    result = phis[0];
+    result = llvm::getSingleElement(phis);
     return builder.saveIP();
   };
   return gen;