Skip to content

Commit 98b6aee

Browse files
authored
[Torch] emit aten.argsort and decompose it to aten.sort (#4027)
1 parent 0271ae1 commit 98b6aee

File tree

8 files changed

+100
-0
lines changed

8 files changed

+100
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15176,6 +15176,31 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
1517615176
let hasFolder = 1;
1517715177
}
1517815178

15179+
// Op definition for `aten::argsort`, matching the registry schema
// `aten::argsort : (Tensor, int, bool) -> (Tensor)`.
// NOTE: this lives in GeneratedTorchOps.td — regenerate via torch_ods_gen.py
// rather than hand-editing.
def Torch_AtenArgsortOp : Torch_Op<"aten.argsort", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::argsort : (Tensor, int, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim,
    Torch_BoolType:$descending
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenArgsortOp::parse(OpAsmParser &parser, OperationState &result) {
      // 3 operands (self, dim, descending), 1 result.
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenArgsortOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
15203+
1517915204
def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
1518015205
AllowsTypeRefinement,
1518115206
ReadOnly

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10295,6 +10295,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1029510295
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
1029610296
" return %1 : !torch.tuple<int, int>\n"
1029710297
" }\n"
10298+
" func.func @\"__torch_mlir_shape_fn.aten.argsort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
10299+
" return %arg0 : !torch.list<int>\n"
10300+
" }\n"
10301+
" func.func @\"__torch_mlir_dtype_fn.aten.argsort\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
10302+
" %int4 = torch.constant.int 4\n"
10303+
" return %int4 : !torch.int\n"
10304+
" }\n"
1029810305
" func.func @\"__torch_mlir_shape_fn.aten.narrow\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
1029910306
" %int1 = torch.constant.int 1\n"
1030010307
" %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10214,6 +10214,31 @@ class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
1021410214
};
1021510215
} // namespace
1021610216

10217+
namespace {
10218+
// decompose aten.argsort to aten.sort
10219+
class DecomposeAtenArgsortOp : public OpRewritePattern<AtenArgsortOp> {
10220+
public:
10221+
using OpRewritePattern::OpRewritePattern;
10222+
LogicalResult matchAndRewrite(AtenArgsortOp op,
10223+
PatternRewriter &rewriter) const override {
10224+
Location loc = op.getLoc();
10225+
auto context = op.getContext();
10226+
10227+
Value self = op.getSelf();
10228+
Value dim = op.getDim();
10229+
Value descending = op.getDescending();
10230+
auto selfType = cast<BaseTensorType>(self.getType());
10231+
auto sortIndicesType = selfType.getWithSizesAndDtype(
10232+
selfType.getOptionalSizes(),
10233+
IntegerType::get(context, 64, IntegerType::Signed));
10234+
auto sortOpResult = rewriter.create<AtenSortOp>(
10235+
loc, self.getType(), sortIndicesType, self, dim, descending);
10236+
rewriter.replaceOp(op, sortOpResult->getResult(1));
10237+
return success();
10238+
}
10239+
};
10240+
} // namespace
10241+
1021710242
namespace {
1021810243

1021910244
/// Creates coefficients based on DFT definition, see
@@ -11781,6 +11806,7 @@ class DecomposeComplexOpsPass
1178111806
patterns);
1178211807
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
1178311808
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
11809+
addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
1178411810
addPatternIfTargetOpIsIllegal<DecomposeAtenFftRfftOp>(patterns);
1178511811
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
1178611812
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
557557
target.addIllegalOp<AtenCrossEntropyLossOp>();
558558
target.addIllegalOp<AtenVarMeanDimOp>();
559559
target.addIllegalOp<AtenTopkOp>();
560+
target.addIllegalOp<AtenArgsortOp>();
560561
target.addIllegalOp<AtenHannWindowPeriodicOp>();
561562
target.addIllegalOp<AtenScalarTensorOp>();
562563
target.addIllegalOp<AtenScatterValueOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@
531531
}
532532

533533
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
534+
"ArgsortTensor_basic",
535+
"ArgsortTensorInteger_basic",
534536
"AddFloatIntModule_basic",
535537
"AtenKthvalueDynamicDimsModule_basic",
536538
"AtenKthvalueFloat64DynamicDimsModule_basic",
@@ -3348,6 +3350,8 @@
33483350
}
33493351

33503352
FX_IMPORTER_TOSA_XFAIL_SET = {
3353+
"ArgsortTensor_basic",
3354+
"ArgsortTensorInteger_basic",
33513355
"AtenSymConstrainRangeForSize_basic",
33523356
"AtenSymConstrainRange_basic",
33533357
"Aten_AssertScalar_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,12 @@ def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descend
19881988
_, input_dtype = self_rank_dtype
19891989
return input_dtype, torch.long
19901990

1991+
def aten〇argsort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> List[int]:
    # argsort permutes elements along `dim` only, so the result shape is the
    # input shape unchanged.
    return self
1994+
def aten〇argsort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descending: bool = False) -> int:
    # Sorting indices are always int64 (torch.long), regardless of the input
    # tensor's dtype.
    return torch.long
19911997
def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]:
19921998
return upstream_shape_functions.slice(self, dim, start, start + length, 1)
19931999

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,7 @@ def emit_with_mutating_variants(key, **kwargs):
10811081
emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True)
10821082
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
10831083
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
1084+
emit("aten::argsort : (Tensor, int, bool) -> (Tensor)")
10841085
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
10851086
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
10861087
emit(

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5357,6 +5357,36 @@ def SortTensorNegativeDimension_basic(module, tu: TestUtils):
53575357
module.forward(tu.rand(3, 4, 5))
53585358

53595359

5360+
class ArgsortTensor(torch.nn.Module):
    """E2E test module: argsort over a rank-3 float tensor (default dim/order)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, input):
        return torch.argsort(input)
5368+
5369+
5370+
@register_test_case(module_factory=lambda: ArgsortTensor())
def ArgsortTensor_basic(module, tu: TestUtils):
    # Exercise argsort on a random rank-3 float tensor.
    module.forward(tu.rand(3, 4, 5))
5373+
5374+
5375+
class ArgsortTensorInteger(torch.nn.Module):
    """E2E test module: argsort over a rank-2 int64 tensor (default dim/order)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int64, True),
        ]
    )
    def forward(self, input):
        return torch.argsort(input)
5383+
5384+
5385+
@register_test_case(module_factory=lambda: ArgsortTensorInteger())
def ArgsortTensorInteger_basic(module, tu: TestUtils):
    # Exercise argsort on a random rank-2 integer tensor.
    module.forward(tu.randint(2, 3))
5388+
5389+
53605390
# ==============================================================================
53615391

53625392

0 commit comments

Comments
 (0)