
Commit 83e22d7

ColinPeppler authored and xuhancn committed

[dynamo] add meta fn for aten.kthvalue.default (pytorch#130562)

I saw:

```
torch._dynamo.exc.Unsupported: unsupported operator: aten.kthvalue.default
```

Pull Request resolved: pytorch#130562
Approved by: https://github.com/jingsh, https://github.com/zou3519
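For context, a minimal repro sketch of the failure this commit fixes (illustrative, not taken from the PR itself): before the meta registration, dynamo raised the `Unsupported` error above whenever a compiled function called `torch.kthvalue`.

```python
import torch

# Before this commit: torch._dynamo.exc.Unsupported
# ("unsupported operator: aten.kthvalue.default") during tracing.
# With the new meta function, the graph compiles cleanly.
@torch.compile(fullgraph=True)
def f(x):
    values, indices = torch.kthvalue(x, k=2, dim=1)
    return values + indices.float()

f(torch.randn(4, 8))
```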
1 parent 3dd6c07 · commit 83e22d7

4 files changed, 16 insertions(+), 10 deletions(-)

test/functorch/test_aotdispatch.py (0 additions, 1 deletion)

```diff
@@ -5947,7 +5947,6 @@ def forward(self, x):
     xfail(
         "index_fill", ""
     ),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail("kthvalue", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail(
         "linalg.lstsq", ""
     ),  # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition
```
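An illustrative check (not part of the PR) that this AOTAutograd expected-failure is no longer needed; `aot_function` and the `nop` compiler come from `functorch.compile`:

```python
import torch
from functorch.compile import aot_function, nop

def f(x):
    return torch.kthvalue(x, k=2, dim=0).values

# AOTAutograd traces with fake tensors, which need a meta kernel;
# this goes through now that kthvalue has one.
aot_f = aot_function(f, fw_compiler=nop)
aot_f(torch.randn(5, 3))
```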

test/test_meta.py (0 additions, 4 deletions)

```diff
@@ -674,7 +674,6 @@ def run_meta_crossref(
     torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
     torch.histogram : {f64, f32},
     torch.histogramdd : {f64, f32},
-    torch.kthvalue : {f64, i32, i64, u8, i16, f16, bf16, i8, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
     torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
     torch.linalg.lstsq : {f64, f32, c128, c64},
@@ -748,7 +747,6 @@ def run_meta_crossref(
     torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
     torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
     torch.geqrf: {f32, f64},  # aten::geqrf
-    torch.kthvalue: {f16},  # aten::kthvalue.values
 }
 
 meta_function_device_skips['cpu'] = {
@@ -846,7 +844,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.histogram.bin_ct : {f32, f64},
     aten.histogram.bins_tensor : {f32, f64},
-    aten.kthvalue.default : {i8, f64, i64, f16, bf16, f32, i32, i16, u8},
     aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
     aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
     aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
@@ -895,7 +892,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
     aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
     aten.geqrf.default: {f32, f64},  # aten::geqrf
-    aten.kthvalue.default: {f16},  # aten::kthvalue.values
     aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
     aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
```
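These expected-failure and skip entries could be dropped because `torch.kthvalue` now runs under the meta device. A quick sanity sketch, assuming a build that includes this patch:

```python
import torch

# Meta tensors carry only shape/dtype metadata; this call succeeds
# only if a meta kernel is registered for aten.kthvalue.
x = torch.empty(4, 8, device="meta")
values, indices = torch.kthvalue(x, k=3, dim=1)
assert values.shape == (4,) and indices.dtype == torch.int64
```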

test/test_proxy_tensor.py (0 additions, 1 deletion)

```diff
@@ -1999,7 +1999,6 @@ def f(t):
     xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
     xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
     xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
-    xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
     xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
```
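Likewise for symbolic-shape tracing, which is what this test file exercises. A sketch using `make_fx` (illustrative, not taken from the test suite):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(t):
    return torch.kthvalue(t, k=1, dim=0)

# Symbolic tracing needs a symbolic meta function/decomposition,
# which this commit provides for kthvalue.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4, 3))
print(gm.graph)
```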

torch/_meta_registrations.py (16 additions, 4 deletions)

```diff
@@ -5772,10 +5772,6 @@ def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
 def topk_meta(self, k, dim=-1, largest=True, sorted=True):
     # From aten/src/ATen/native/Sorting.cpp
     dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
-    torch._check(
-        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
-        lambda: "selected index k out of range",
-    )
     sliceSize = 1 if self.dim() == 0 else self.size(dim)
     torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
 
@@ -5785,6 +5781,22 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
     return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
 
 
+@register_meta([aten.kthvalue.default, aten.kthvalue.values])
+@out_wrapper("values", "indices")
+def kthvalue_meta(self, k, dim=-1, keepdim=False):
+    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
+    dimSize = self.size(dim) if self.dim() > 0 else 1
+    torch._check(
+        k >= 1 and k <= dimSize,
+        lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
+    )
+
+    shape = list(self.shape[:dim] + self.shape[dim + 1 :])
+    if keepdim and self.dim() > 0:
+        shape.insert(dim, 1)
+    return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
+
+
 legacy_contiguous_memory_format = torch.contiguous_format
```
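The output shapes the new meta function produces mirror the eager semantics of `kthvalue`: the reduced dimension is removed, or kept with size 1 when `keepdim=True`. A small sketch of both cases on the meta device:

```python
import torch

x = torch.empty(3, 5, 7, device="meta")

values, indices = torch.kthvalue(x, k=2, dim=1)
print(values.shape)  # torch.Size([3, 7]) -- dim 1 removed

values, indices = torch.kthvalue(x, k=2, dim=1, keepdim=True)
print(values.shape)  # torch.Size([3, 1, 7]) -- dim 1 kept as size 1
```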