@@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
+  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
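+// Emits a stablehlo.reduce that reduces `input` over `dims` into `outTy`,
+// seeded with the op-specific identity from createInitialValueForReduceOp.
+// Returns nullptr when the reduction body for `op` is unimplemented, so
+// callers can turn that into a match failure.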
+static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
+                                              Type outTy,
+                                              ArrayRef<int64_t> dims,
+                                              PatternRewriter &rewriter) {
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy)
+    return nullptr;
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return nullptr;
+
+  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
+      op->getLoc(), outTy, input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
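+  // The reduction region takes two rank-0 arguments and yields their
+  // combination; for the max-style ops that combination is stablehlo.max.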
+  Block &block = reduce.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result;
+    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
+      result = rewriter.create<stablehlo::MaxOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else {
+      op->emitError("unimplemented lowering in "
+                    "createReduceOpWithSingleRegionOp");
+      return nullptr;
+    }
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+  }
+  return reduce.getResults()[0];
+}
+
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static std::optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
@@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
 
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-
-    auto stablehloReduceValueResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), valResultType, stablehloReduceResults[0],
-            outShapeTensor);
-    auto stablehloReduceIndexResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), idxResultType, stablehloReduceResults[1],
-            outShapeTensor);
-    rewriter.replaceOp(
-        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
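+  // Fast path: the indices result is never used, so a plain stablehlo.reduce
+  // over the values suffices and the argmax computation can be skipped.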
+  if (op.getResult(1).use_empty()) {
+    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+    outputShape.erase(outputShape.begin() + dim);
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(outputShape, inputElemTy),
+        ArrayRef<int64_t>{dim}, rewriter);
+    if (!reduceResult)
+      return failure();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, reduceResult, outShapeTensor);
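+      // The dead indices result (result #1) is tied off with a null Value.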
+      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
+      return success();
+    }
+    rewriter.replaceOp(op, {reduceResult, Value()});
+    return success();
+  } else {
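+    // The indices result is live: getMaxInDim computes the max values and
+    // their indices together.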
+    auto stablehloReduceResults =
+        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                    options.dimSizeIndexBits)
+            .value();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, stablehloReduceResults[0],
+              outShapeTensor);
+      auto stablehloReduceIndexResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), idxResultType, stablehloReduceResults[1],
+              outShapeTensor);
+      rewriter.replaceOp(
+          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+      return success();
+    }
+    rewriter.replaceOp(op,
+                       {stablehloReduceResults[0], stablehloReduceResults[1]});
     return success();
   }
-
-  rewriter.replaceOp(op,
-                     {stablehloReduceResults[0], stablehloReduceResults[1]});
-  return success();
 }
 } // namespace
 
@@ -692,11 +761,11 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
 }
 } // namespace
 
-// AtenMaxOp
+// AtenAmaxOp
 namespace {
 template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
-    AtenMaxOp op, OpAdaptor adaptor,
+LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
+    AtenAmaxOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.getSelf();
   auto inputTy = dyn_cast<RankedTensorType>(input.getType());
@@ -717,40 +786,102 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
             "AtenMaxOp to StableHLO");
   }
 
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
+  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    // Drop invalid dims
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
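+  // The reduce result shape keeps only the dimensions that are not reduced.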
+  SmallVector<int64_t> reduceResultShape;
   for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputTy.getDimSize(i));
+    }
   }
 
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+      rewriter);
+  if (!reduceResult)
     return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
 
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
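+  // For keepdim, restore the reduced dimensions as size-1 dims via a dynamic
+  // reshape of the reduce result.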
+  if (keepDim) {
+    const auto &options = getOptions();
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto outShapeVec = *outShapeInfo;
+    auto one = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    for (int64_t i : dims) {
+      outShapeVec[i] = one;
+    }
+    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), reduceResult,
+        outShapeTensor);
+    return success();
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
 
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
+// AtenMaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
+    AtenMaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenMaxOp to StableHLO");
+  }
 
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
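+  // aten.max with no `dim` argument reduces over every dimension of the input.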
+  SmallVector<int64_t> dims =
+      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
 
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<stablehlo::MaxOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
-  }
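+  // Reduce to a rank-0 (scalar) result via the shared reduce helper.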
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
+  if (!reduceResult)
+    return failure();
 
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
+      op, getTypeConverter()->convertType(op.getType()), reduceResult);
   return success();
 }
 } // namespace
@@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);