Skip to content

Commit 530de07

Browse files
[llvm] Add getSingleElement helper and use in MLIR
1 parent: 6a030b3 · commit: 530de07

File tree

23 files changed

+67
-102
lines changed

23 files changed

+67
-102
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 8 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
325325
return B != E && std::next(B) == E;
326326
}
327327

328+
/// Asserts that the given container has a single element and returns that
329+
/// element.
330+
template <typename ContainerTy>
331+
decltype(auto) getSingleElement(ContainerTy &&C) {
332+
assert(hasSingleElement(C) && "expected container with single element");
333+
return *adl_begin(C);
334+
}
335+
328336
/// Return a range covering \p RangeOrContainer with the first N elements
329337
/// excluded.
330338
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -196,8 +196,7 @@ template <class AttrElementT,
196196
function_ref<std::optional<ElementValueT>(ElementValueT)>>
197197
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
198198
CalculationT &&calculate) {
199-
assert(operands.size() == 1 && "unary op takes one operands");
200-
if (!operands[0])
199+
if (!llvm::getSingleElement(operands))
201200
return {};
202201

203202
static_assert(
@@ -268,8 +267,7 @@ template <
268267
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
269268
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
270269
CalculationT &&calculate) {
271-
assert(operands.size() == 1 && "Cast op takes one operand");
272-
if (!operands[0])
270+
if (!llvm::getSingleElement(operands))
273271
return {};
274272

275273
static_assert(

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
107107
// into us. For now, just bail.
108108
if (parentOp && backwardSlice->count(parentOp) == 0) {
109109
assert(parentOp->getNumRegions() == 1 &&
110-
parentOp->getRegion(0).getBlocks().size() == 1);
110+
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
111111
getBackwardSliceImpl(parentOp, backwardSlice, options);
112112
}
113113
} else {

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
834834
LogicalResult
835835
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836836
ConversionPatternRewriter &rewriter) const override {
837-
assert(adaptor.getOperands().size() == 1);
838-
Type srcType = adaptor.getOperands().front().getType();
837+
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
839838
Type dstType = this->getTypeConverter()->convertType(op.getType());
840839
if (!dstType)
841840
return getTypeConversionFailure(rewriter, op);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
101101
LogicalResult
102102
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103103
ConversionPatternRewriter &rewriter) const override {
104-
assert(adaptor.getOperands().size() == 1);
105-
Value cst = adaptor.getOperands().front();
104+
Value cst = llvm::getSingleElement(adaptor.getOperands());
106105
auto coopType = getTypeConverter()->convertType(op.getType());
107106
if (!coopType)
108107
return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
181180
"splat is not a composite construct");
182181
}
183182

184-
assert(cc.getConstituents().size() == 1);
185-
scalar = cc.getConstituents().front();
183+
scalar = llvm::getSingleElement(cc.getConstituents());
186184

187185
auto coopType = getTypeConverter()->convertType(op.getType());
188186
if (!coopType)

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 13 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
419419
SmallVector<Value> dynDims, dynDevice;
420420
for (auto dim : adaptor.getDimsDynamic()) {
421421
// type conversion should be 1:1 for ints
422-
assert(dim.size() == 1);
423-
dynDims.emplace_back(dim[0]);
422+
dynDims.emplace_back(llvm::getSingleElement(dim));
424423
}
425424
// same for device
426425
for (auto device : adaptor.getDeviceDynamic()) {
427-
assert(device.size() == 1);
428-
dynDevice.emplace_back(device[0]);
426+
dynDevice.emplace_back(llvm::getSingleElement(device));
429427
}
430428

431429
// To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
771769
typeConverter.addConversion([](Type type) { return type; });
772770

773771
// convert mesh::ShardingType to a tuple of RankedTensorTypes
774-
typeConverter.addConversion(
775-
[](ShardingType type,
776-
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
777-
auto i16 = IntegerType::get(type.getContext(), 16);
778-
auto i64 = IntegerType::get(type.getContext(), 64);
779-
std::array<int64_t, 2> shp = {ShapedType::kDynamic,
780-
ShapedType::kDynamic};
781-
results.emplace_back(RankedTensorType::get(shp, i16));
782-
results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
783-
results.emplace_back(RankedTensorType::get(shp, i64));
784-
return success();
785-
});
772+
typeConverter.addConversion([](ShardingType type,
773+
SmallVectorImpl<Type> &results)
774+
-> std::optional<LogicalResult> {
775+
auto i16 = IntegerType::get(type.getContext(), 16);
776+
auto i64 = IntegerType::get(type.getContext(), 64);
777+
std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
778+
results.emplace_back(RankedTensorType::get(shp, i16));
779+
results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
780+
results.emplace_back(RankedTensorType::get(shp, i64));
781+
return success();
782+
});
786783

787784
// To 'extract' components, a UnrealizedConversionCastOp is expected
788785
// to define the input

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12361236
}
12371237

12381238
applyOp->erase();
1239-
assert(foldResults.size() == 1 && "expected 1 folded result");
1240-
return foldResults.front();
1239+
return llvm::getSingleElement(foldResults);
12411240
}
12421241

12431242
OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
13061305
}
13071306

13081307
minMaxOp->erase();
1309-
assert(foldResults.size() == 1 && "expected 1 folded result");
1310-
return foldResults.front();
1308+
return llvm::getSingleElement(foldResults);
13111309
}
13121310

13131311
OpFoldResult

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
12491249
SmallVector<Operation *, 2> sibLoadOpInsts;
12501250
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
12511251
// Currently findSiblingNodeToFuse searches for siblings with one load.
1252-
assert(sibLoadOpInsts.size() == 1);
1253-
Operation *sibLoadOpInst = sibLoadOpInsts[0];
1252+
Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
12541253

12551254
// Gather 'dstNode' load ops to 'memref'.
12561255
SmallVector<Operation *, 2> dstLoadOpInsts;

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
16041604
ArrayRef<uint64_t> sizes,
16051605
AffineForOp target) {
16061606
SmallVector<AffineForOp, 8> res;
1607-
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
1608-
assert(loops.size() == 1);
1609-
res.push_back(loops[0]);
1610-
}
1607+
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
1608+
res.push_back(llvm::getSingleElement(loops));
16111609
return res;
16121610
}
16131611

mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Lines changed: 5 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
4444
linalg::CopyOp> {
4545
OpOperand &getSourceOperand(Operation *op) const {
4646
auto copyOp = cast<CopyOp>(op);
47-
assert(copyOp.getInputs().size() == 1 && "expected single input");
48-
return copyOp.getInputsMutable()[0];
47+
return llvm::getSingleElement(copyOp.getInputsMutable());
4948
}
5049

5150
bool
5251
isEquivalentSubset(Operation *op, Value candidate,
5352
function_ref<bool(Value, Value)> equivalenceFn) const {
5453
auto copyOp = cast<CopyOp>(op);
55-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
56-
return equivalenceFn(candidate, copyOp.getOutputs()[0]);
54+
return equivalenceFn(candidate,
55+
llvm::getSingleElement(copyOp.getOutputs()));
5756
}
5857

5958
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
6059
Location loc) const {
6160
auto copyOp = cast<CopyOp>(op);
62-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
63-
return copyOp.getOutputs()[0];
61+
return llvm::getSingleElement(copyOp.getOutputs());
6462
}
6563

6664
SmallVector<Value>
6765
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
6866
auto copyOp = cast<CopyOp>(op);
69-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
70-
return {copyOp.getOutputs()[0]};
67+
return {llvm::getSingleElement(copyOp.getOutputs())};
7168
}
7269
};
7370
} // namespace

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
471471
/// extending the lifetime of allocations.
472472
static bool lastNonTerminatorInRegion(Operation *op) {
473473
return op->getNextNode() == op->getBlock()->getTerminator() &&
474-
op->getParentRegion()->getBlocks().size() == 1;
474+
llvm::hasSingleElement(op->getParentRegion()->getBlocks());
475475
}
476476

477477
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope

mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
4646

4747
static Value materializeConversion(OpBuilder &builder, Type type,
4848
ValueRange inputs, Location loc) {
49-
assert(inputs.size() == 1);
50-
return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
49+
return builder.create<quant::StorageCastOp>(loc, type,
50+
llvm::getSingleElement(inputs));
5151
}
5252

5353
public:

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
5252
static bool doesNotAliasExternalValue(Value value, Region *region,
5353
ValueRange exceptions,
5454
const OneShotAnalysisState &state) {
55-
assert(region->getBlocks().size() == 1 &&
55+
assert(llvm::hasSingleElement(region->getBlocks()) &&
5656
"expected region with single block");
5757
bool result = true;
5858
state.applyOnAliases(value, [&](Value alias) {

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 5 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
2424
return result;
2525
}
2626

27-
/// Assert that the given value range contains a single value and return it.
28-
static Value getSingleValue(ValueRange values) {
29-
assert(values.size() == 1 && "expected single value");
30-
return values.front();
31-
}
32-
3327
// CRTP
3428
// A base class that takes care of 1:N type conversion, which maps the converted
3529
// op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
119113
// We can not do clone as the number of result types after conversion
120114
// might be different.
121115
ForOp newOp = rewriter.create<ForOp>(
122-
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
123-
getSingleValue(adaptor.getUpperBound()),
124-
getSingleValue(adaptor.getStep()),
116+
op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
117+
llvm::getSingleElement(adaptor.getUpperBound()),
118+
llvm::getSingleElement(adaptor.getStep()),
125119
flattenValues(adaptor.getInitArgs()));
126120

127121
// Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
149143
TypeRange dstTypes) const {
150144

151145
IfOp newOp = rewriter.create<IfOp>(
152-
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
146+
op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
147+
true);
153148
newOp->setAttrs(op->getAttrs());
154149

155150
// We do not need the empty blocks created by rewriter.

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
13101310
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
13111311
scf::ForOp target) {
13121312
SmallVector<scf::ForOp, 8> res;
1313-
for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
1314-
assert(loops.size() == 1);
1315-
res.push_back(loops[0]);
1316-
}
1313+
for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1314+
res.push_back(llvm::getSingleElement(loops));
13171315
return res;
13181316
}
13191317

mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -38,7 +38,7 @@ struct AssumingOpInterface
3838
size_t resultNum = std::distance(op->getOpResults().begin(),
3939
llvm::find(op->getOpResults(), value));
4040
// TODO: Support multiple blocks.
41-
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
41+
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
4242
"expected exactly 1 block");
4343
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
4444
assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
4949
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
5050
const BufferizationOptions &options) const {
5151
auto assumingOp = cast<shape::AssumingOp>(op);
52-
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
52+
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
5353
"only 1 block supported");
5454
auto yieldOp = cast<shape::AssumingYieldOp>(
5555
assumingOp.getDoRegion().front().getTerminator());

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 3 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -12,12 +12,6 @@
1212
using namespace mlir;
1313
using namespace mlir::sparse_tensor;
1414

15-
/// Assert that the given value range contains a single value and return it.
16-
static Value getSingleValue(ValueRange values) {
17-
assert(values.size() == 1 && "expected single value");
18-
return values.front();
19-
}
20-
2115
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
2216
SmallVectorImpl<Type> &fields) {
2317
// Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
200194

201195
// Construct the iteration space.
202196
SparseIterationSpace space(loc, rewriter,
203-
getSingleValue(adaptor.getTensor()), 0,
197+
llvm::getSingleElement(adaptor.getTensor()), 0,
204198
op.getLvlRange(), adaptor.getParentIter());
205199

206200
SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
218212
ConversionPatternRewriter &rewriter) const override {
219213
Location loc = op.getLoc();
220214
Value pos = adaptor.getIterator().back();
221-
Value valBuf =
222-
rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
215+
Value valBuf = rewriter.create<ToValuesOp>(
216+
loc, llvm::getSingleElement(adaptor.getTensor()));
223217
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
224218
return success();
225219
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 5 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
4747
return result;
4848
}
4949

50-
/// Assert that the given value range contains a single value and return it.
51-
static Value getSingleValue(ValueRange values) {
52-
assert(values.size() == 1 && "expected single value");
53-
return values.front();
54-
}
55-
5650
/// Generates a load with proper `index` typing.
5751
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
5852
idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
962956
SmallVector<Value> fields;
963957
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
964958
op.getTensor().getType());
965-
Value values = getSingleValue(adaptor.getValues());
966-
Value filled = getSingleValue(adaptor.getFilled());
967-
Value added = getSingleValue(adaptor.getAdded());
968-
Value count = getSingleValue(adaptor.getCount());
959+
Value values = llvm::getSingleElement(adaptor.getValues());
960+
Value filled = llvm::getSingleElement(adaptor.getFilled());
961+
Value added = llvm::getSingleElement(adaptor.getAdded());
962+
Value count = llvm::getSingleElement(adaptor.getCount());
969963
const SparseTensorType dstType(desc.getRankedTensorType());
970964
Type eltType = dstType.getElementType();
971965

@@ -1041,7 +1035,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
10411035
SmallVector<Value> params = llvm::to_vector(desc.getFields());
10421036
SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
10431037
params.append(flatIndices.begin(), flatIndices.end());
1044-
params.push_back(getSingleValue(adaptor.getScalar()));
1038+
params.push_back(llvm::getSingleElement(adaptor.getScalar()));
10451039
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
10461040
params, /*genCall=*/true);
10471041
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -521,9 +521,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
521521
Value ptr = genSubscript(env, builder, t, args);
522522
if (llvm::isa<TensorType>(ptr.getType())) {
523523
assert(env.options().sparseEmitStrategy ==
524-
SparseEmitStrategy::kSparseIterator &&
525-
args.size() == 1);
526-
return builder.create<ExtractValOp>(loc, ptr, args.front());
524+
SparseEmitStrategy::kSparseIterator);
525+
return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
527526
}
528527
return builder.create<memref::LoadOp>(loc, ptr, args);
529528
}

0 commit comments

Comments
 (0)