Skip to content

Commit 677ada2

Browse files
committed
Address antiagainst's PR comment
1 parent eee2ea9 commit 677ada2

File tree

4 files changed

+38
-17
lines changed

4 files changed

+38
-17
lines changed

mlir/include/mlir/IR/Dialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ class Dialect {
216216
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
217217
}
218218

219+
// Declare the same interface for multiple types.
220+
// Example:
221+
// declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
219222
template <typename InterfaceT, typename... ConcreteT>
220223
void declarePromisedInterfaces() {
221224
(declarePromisedInterface<ConcreteT, InterfaceT>(), ...);

mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- AllInterfaces.cpp - --------------------------------------*- C++ -*-===//
1+
//===- AllInterfaces.cpp - ------------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ static ReductionKind getReductionKind(Operation *op) {
5959
.Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
6060
.Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
6161
.Case([](arith::AndIOp op) { return ReductionKind::Sum; })
62+
// TODO: handle signless, signed and unsigned types properly.
63+
// It is assumed that the element type of the collective operands and
64+
// result drive the meaning of the reduction kind, whether it is signed
65+
// or unsigned.
66+
// The reduction op inside the linalg op may have different result type
67+
// from the element type of the linalg op's result.
68+
// Also signed and unsigned Arith dialect ops may accept signed, unsigned
69+
// or signless operands.
70+
// Maybe expand the reduction kinds.
6271
.Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
6372
.Case([](arith::MinUIOp op) { return ReductionKind::Min; })
6473
.Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
@@ -67,7 +76,7 @@ static ReductionKind getReductionKind(Operation *op) {
6776
.Default([](Operation *op) { return ReductionKind::Generic; });
6877
}
6978

70-
static std::optional<Operation *> getReductionOp(LinalgOp op) {
79+
static std::optional<Operation *> getCombinerOp(LinalgOp op) {
7180
SmallVector<Operation *> combinerOps;
7281
Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
7382
if (!reducedValue || combinerOps.size() != 1) {
@@ -78,10 +87,16 @@ static std::optional<Operation *> getReductionOp(LinalgOp op) {
7887
}
7988

8089
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
81-
std::optional<Operation *> reductionOp = getReductionOp(op);
90+
std::optional<Operation *> reductionOp = getCombinerOp(op);
8291
if (!reductionOp) {
8392
return ReductionKind::Generic;
8493
}
94+
Type resultElementType =
95+
llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
96+
// TODO: handle case when result type of the reduction op does not match the
97+
// element type of the result tensor.
98+
// Would it makes sense at all?
99+
assert(resultElementType == reductionOp.value()->getResult(0).getType());
85100
return getReductionKind(reductionOp.value());
86101
}
87102

@@ -276,9 +291,8 @@ struct StructuredOpShardingInterface
276291
});
277292
if (!allIndexingMapsAreProjectedPermutation) {
278293
// TODO: handle non-projected permutations.
279-
op->emitOpError()
280-
<< "Only projected permutation indexing maps are supported.";
281-
return failure();
294+
return op->emitOpError()
295+
<< "supports indexing maps that are only projected permutation.";
282296
}
283297

284298
SmallVector<utils::IteratorType> loopIteratorTypes =

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
539539
if (std::size(values) != std::size(shardings)) {
540540
return false;
541541
}
542-
return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
543-
std::forward<MeshShardingAttrRage>(shardings)),
542+
return llvm::all_of(llvm::zip_equal(
543+
std::forward<ValueRange>(values),
544+
std::forward<MeshShardingAttrRage>(shardings)),
544545
[](auto valueAndSharding) {
545546
return isValueCompatibleWithFullReplicationSharding(
546547
std::get<0>(valueAndSharding),
@@ -588,11 +589,9 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
588589
SmallVector<MeshShardingAttr> operatorAndResultShardings;
589590
operatorAndResultShardings.reserve(operandShardings.size() +
590591
resultShardings.size());
591-
operatorAndResultShardings.insert(operatorAndResultShardings.end(),
592-
operandShardings.begin(),
593-
operandShardings.end());
592+
llvm::append_range(operatorAndResultShardings, operandShardings);
594593
for (auto [sharding, affineMap] :
595-
llvm::zip(operatorAndResultShardings, indexingMaps)) {
594+
llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
596595
if (!sharding) {
597596
continue;
598597
}
@@ -602,6 +601,12 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
602601
meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
603602
meshAxisAssignmentForLoopIterators);
604603
}
604+
// Missing trailing split axes means replication on those tensor dimensions.
605+
for (unsigned i = sharding.getSplitAxes().size();
606+
i < affineMap.getNumResults(); ++i) {
607+
updateMeshAxisAssignmentForLoopIterators(
608+
{}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
609+
}
605610
}
606611

607612
ShardingArray res;
@@ -619,7 +624,7 @@ bool mesh::isAtLeastOneReductionIteratorSharded(
619624
ArrayRef<utils::IteratorType> loopIteratorTypes,
620625
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
621626
for (auto [loopIteratorType, meshAxisAssignment] :
622-
llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
627+
llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
623628
if (loopIteratorType == utils::IteratorType::reduction &&
624629
!meshAxisAssignment.empty()) {
625630
return true;
@@ -633,10 +638,9 @@ SmallVector<MeshAxis> mesh::getReductionMeshAxes(
633638
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
634639
SmallVector<MeshAxis> meshAxes;
635640
for (auto [loopIteratorType, meshAxisAssignment] :
636-
llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
641+
llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
637642
if (loopIteratorType == utils::IteratorType::reduction) {
638-
meshAxes.insert(meshAxes.end(), meshAxisAssignment.begin(),
639-
meshAxisAssignment.end());
643+
llvm::append_range(meshAxes, meshAxisAssignment);
640644
}
641645
}
642646
return meshAxes;
@@ -651,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation(
651655
Operation *newOp = builder.clone(op, spmdizationMap);
652656
// Set the result types to the sharded counterparts.
653657
for (auto [oldResult, newResult, sharding] :
654-
llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) {
658+
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
655659
newResult.setType(shardType(newResult.getType(),
656660
getMesh(&op, sharding.getMesh(), symbolTable),
657661
sharding));

0 commit comments

Comments
 (0)