Commit 5928f68
[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo (#3348)
* Do not decompose `aten.amax` on the `stablehlo` backend, because it can be lowered directly to `stablehlo.reduce`.
* Lower `aten.max.dim` to a `stablehlo.reduce` that applies `max` when `AtenMaxDimOp.getIndices()` has no users; this is simpler than the combined value+index reduction.
1 parent 6b95dd4 commit 5928f68
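
For illustration, a minimal PyTorch-level sketch (not part of the commit; shapes and dims are arbitrary) of the two patterns this change targets:

```python
import torch

def amax_not_decomposed(x: torch.Tensor) -> torch.Tensor:
    # `aten.amax` stays backend-legal for StableHLO and is lowered
    # directly to `stablehlo.reduce` with a max region.
    return torch.amax(x, dim=(0, 2), keepdim=True)

def max_dim_values_only(x: torch.Tensor) -> torch.Tensor:
    # The indices result of `aten.max.dim` has no users here, so the
    # new lowering can emit a single value-only `stablehlo.reduce`
    # instead of the combined value+index reduction.
    values, _ = torch.max(x, dim=1)
    return values
```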

2 files changed: +186 -54 lines

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 185 additions & 53 deletions

@@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
+  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
+                                              Type outTy,
+                                              ArrayRef<int64_t> dims,
+                                              PatternRewriter &rewriter) {
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy)
+    return nullptr;
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return nullptr;
+
+  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
+      op->getLoc(), outTy, input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = reduce.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result;
+    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
+      result = rewriter.create<stablehlo::MaxOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else {
+      op->emitError("unimplemented lowering in "
+                    "createReduceOpWithSingleRegionOp");
+      return nullptr;
+    }
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+  }
+  return reduce.getResults()[0];
+}
+
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static std::optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
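
For intuition, a minimal Python model (illustrative only, not from the commit) of what the reduce op built by `createReduceOpWithSingleRegionOp` computes for max-like ops: fold the region body `max(a, b)` over the reduced dimensions, seeded with the initial value picked by `createInitialValueForReduceOp` (the dtype's lowest value, e.g. -inf for floats):

```python
import numpy as np

def reduce_max(x: np.ndarray, dims: tuple) -> np.ndarray:
    # Init value for max-like reductions: -inf for floats, INT_MIN for ints.
    if np.issubdtype(x.dtype, np.floating):
        init = -np.inf
    else:
        init = np.iinfo(x.dtype).min
    out_shape = [s for i, s in enumerate(x.shape) if i not in dims]
    out = np.full(out_shape, init, dtype=x.dtype)
    for idx in np.ndindex(*x.shape):
        out_idx = tuple(v for i, v in enumerate(idx) if i not in dims)
        # The single-region body: result = max(lhs, rhs).
        out[out_idx] = max(out[out_idx], x[idx])
    return out
```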
@@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
 
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-
-    auto stablehloReduceValueResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), valResultType, stablehloReduceResults[0],
-            outShapeTensor);
-    auto stablehloReduceIndexResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), idxResultType, stablehloReduceResults[1],
-            outShapeTensor);
-    rewriter.replaceOp(
-        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+  if (op.getResult(1).use_empty()) {
+    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+    outputShape.erase(outputShape.begin() + dim);
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(outputShape, inputElemTy),
+        ArrayRef<int64_t>{dim}, rewriter);
+    if (!reduceResult)
+      return failure();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, reduceResult, outShapeTensor);
+      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
+      return success();
+    }
+    rewriter.replaceOp(op, {reduceResult, Value()});
+    return success();
+  } else {
+    auto stablehloReduceResults =
+        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                    options.dimSizeIndexBits)
+            .value();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, stablehloReduceResults[0],
+              outShapeTensor);
+      auto stablehloReduceIndexResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), idxResultType, stablehloReduceResults[1],
+              outShapeTensor);
+      rewriter.replaceOp(
+          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+      return success();
+    }
+    rewriter.replaceOp(op,
+                       {stablehloReduceResults[0], stablehloReduceResults[1]});
     return success();
   }
-
-  rewriter.replaceOp(op,
-                     {stablehloReduceResults[0], stablehloReduceResults[1]});
-  return success();
 }
 } // namespace
 
@@ -692,11 +761,11 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
 }
 } // namespace
 
-// AtenMaxOp
+// AtenAmaxOp
 namespace {
 template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
-    AtenMaxOp op, OpAdaptor adaptor,
+LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
+    AtenAmaxOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.getSelf();
   auto inputTy = dyn_cast<RankedTensorType>(input.getType());
@@ -717,40 +786,102 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
                "AtenMaxOp to StableHLO");
   }
 
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
+  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    // Drop invalid dims
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
   for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputTy.getDimSize(i));
+    }
   }
 
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+      rewriter);
+  if (!reduceResult)
     return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
 
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  if (keepDim) {
+    const auto &options = getOptions();
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto outShapeVec = *outShapeInfo;
+    auto one = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    for (int64_t i : dims) {
+      outShapeVec[i] = one;
+    }
+    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), reduceResult,
+        outShapeTensor);
+    return success();
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
 
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
+// AtenMaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
+    AtenMaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in convertion from "
+            "AtenMaxOp to StableHLO");
+  }
 
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
+  SmallVector<int64_t> dims =
+      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
 
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<stablehlo::MaxOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
-  }
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
+  if (!reduceResult)
+    return failure();
 
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
+      op, getTypeConverter()->convertType(op.getType()), reduceResult);
   return success();
 }
 } // namespace
@@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);

projects/pt1/python/torch_mlir/torchscript.py

Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ def _get_for_tracing(
             "aten.adaptive_avg_pool2d",
             "aten.unflatten.int",
         ],
-        OutputType.STABLEHLO: [],
+        OutputType.STABLEHLO: ["aten.amax"],
     }
 
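
A hedged usage sketch (the module name `AmaxModule` and the input shape are hypothetical; `torchscript.compile` is the pt1 entry point this file belongs to): with `aten.amax` now registered as backend-legal for StableHLO under tracing, the op should survive import intact and take the new direct `stablehlo.reduce` lowering instead of being decomposed.

```python
import torch
from torch_mlir import torchscript

class AmaxModule(torch.nn.Module):
    def forward(self, x):
        return torch.amax(x, dim=(0, 2), keepdim=True)

# use_tracing=True exercises _get_for_tracing, which now lists
# "aten.amax" among the StableHLO backend-legal ops.
module = torchscript.compile(
    AmaxModule(), torch.rand(2, 3, 4),
    output_type="stablehlo", use_tracing=True)
print(module)
```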
