Skip to content

Commit afabed6

Browse files
jamesjwu authored and pytorchmergebot committed
[inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (pytorch#117298)
Fixes pytorch#116715. Pull Request resolved: pytorch#117298. Approved by: https://github.com/eellison
1 parent 4155632 commit afabed6

File tree

5 files changed

+141
-13
lines changed

5 files changed

+141
-13
lines changed

aten/src/ATen/native/tags.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,10 @@
4242
desc: |
4343
This tag indicates if an operator doesn't guarantee bitwise equivalence
4444
across different runs of an operator with identical inputs.
45+
- tag: needs_fixed_stride_order
46+
desc: |
47+
This tag indicates that the operator should be passed Tensors following
48+
the same stride permutation as observed in eager when compiled in inductor.
4549
4650
# NOTE [Core ATen Ops]
4751
- tag: core

test/inductor/test_torchinductor.py

Lines changed: 108 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,20 @@
116116
)
117117
vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
118118

119-
libfoo = None
119+
libtest = torch.library.Library("test", "FRAGMENT")
120+
ids = set()
121+
122+
123+
def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
124+
global libtest
125+
global ids
126+
if id_ not in ids:
127+
libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags)
128+
libtest.impl(id_, fn_cpu, "CPU")
129+
libtest.impl(id_, fn_cuda, "CUDA")
130+
libtest.impl(id_, fn_meta, "Meta")
131+
ids.add(id_)
132+
120133

121134
f32 = torch.float32
122135

@@ -7989,22 +8002,108 @@ def foo_cuda(x):
79898002
def foo_meta(x):
79908003
return torch.empty_like(x)
79918004

7992-
global libfoo
7993-
if libfoo is None:
7994-
libfoo = torch.library.Library("foo", "DEF")
7995-
libfoo.define("custom(Tensor self) -> Tensor")
7996-
libfoo.impl("custom", foo_cpu, "CPU")
7997-
libfoo.impl("custom", foo_cuda, "CUDA")
7998-
libfoo.impl("custom", foo_meta, "Meta")
8005+
define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_meta)
79998006

80008007
def fn(x):
80018008
a = torch.nn.functional.relu(x)
8002-
b = torch.ops.foo.custom(a)
8009+
b = torch.ops.test.foo(a)
80038010
c = torch.cos(b)
80048011
return c
80058012

80068013
self.common(fn, (torch.randn((16, 32)),), check_lowp=False)
80078014

8015+
@requires_gpu()
8016+
@torch._inductor.config.patch("layout_optimization", True)
8017+
@torch._inductor.config.patch("keep_output_stride", False)
8018+
@config.patch(implicit_fallbacks=True)
8019+
def test_custom_op_fixed_layout_sequential(self):
8020+
import torch.library
8021+
8022+
mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).cuda()
8023+
inp = torch.rand(2, 3, 128, 128, device="cuda")
8024+
expected_stride = mod(inp).stride()
8025+
8026+
def bar_cpu(x):
8027+
self.assertEqual(x.stride(), expected_stride)
8028+
return x.clone()
8029+
8030+
def bar_cuda(x):
8031+
self.assertEqual(x.stride(), expected_stride)
8032+
return x.clone()
8033+
8034+
def bar_meta(x):
8035+
return torch.empty_like(x)
8036+
8037+
define_custom_op_for_test(
8038+
"bar",
8039+
bar_cpu,
8040+
bar_cuda,
8041+
bar_meta,
8042+
tags=[torch._C.Tag.needs_fixed_stride_order],
8043+
)
8044+
8045+
def fn(x):
8046+
z = mod(x)
8047+
output = torch.ops.test.bar(z)
8048+
return output
8049+
8050+
with torch.no_grad():
8051+
# With keep_output_stride False, inductor would normally have different layout from eager execution
8052+
# But because our custom op needs fixed layout, the assertions in the custom op will pass
8053+
self.common(fn, (inp,), check_lowp=False)
8054+
8055+
@requires_gpu()
8056+
@config.patch(implicit_fallbacks=True)
8057+
def test_custom_op_fixed_layout_channels_last(self):
8058+
class Block(nn.Module):
8059+
def __init__(
8060+
self,
8061+
):
8062+
super().__init__()
8063+
8064+
self.in_layers = nn.Sequential(
8065+
nn.Dropout(p=0.1),
8066+
)
8067+
8068+
def helper(self, x):
8069+
out = F.gelu(x)
8070+
out = self.in_layers(out)
8071+
return out
8072+
8073+
def forward(self, x):
8074+
out = self.helper(x)
8075+
out = torch.ops.test.baz(out)
8076+
return out
8077+
8078+
model = Block()
8079+
model = model.to("cuda").to(memory_format=torch.channels_last)
8080+
input_t = torch.randn([1, 320, 128, 128], dtype=torch.float32, device="cuda")
8081+
input_t = input_t.to(memory_format=torch.channels_last)
8082+
expected_strides = model.helper(input_t).stride()
8083+
8084+
def baz_cpu(x):
8085+
self.assertEqual(expected_strides, x.stride())
8086+
return x.clone()
8087+
8088+
def baz_cuda(x):
8089+
self.assertEqual(expected_strides, x.stride())
8090+
return x.clone()
8091+
8092+
def baz_meta(x):
8093+
return torch.empty_like(x)
8094+
8095+
define_custom_op_for_test(
8096+
"baz",
8097+
baz_cpu,
8098+
baz_cuda,
8099+
baz_meta,
8100+
tags=[torch._C.Tag.needs_fixed_stride_order],
8101+
)
8102+
8103+
with torch.no_grad():
8104+
net = torch.compile(model)
8105+
out = net(input_t)
8106+
80088107
def test_buffer_use_after_remove(self):
80098108
# https://github.com/pytorch/pytorch/issues/102857
80108109

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,9 @@ def run(*ex, **kwargs):
198198
("cpu", "cuda")
199199
),
200200
"test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda")),
201+
"test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure(
202+
("cpu", "cuda")
203+
),
201204
"test_cat_uint8_dynamic_shapes": TestFailure(
202205
("cpu",)
203206
), # cat on uint8 input is using aten fallback on cpu

torch/_custom_op/impl.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -958,14 +958,14 @@ def get_abstract_impl(qualname):
958958
return custom_op._get_impl("abstract").func
959959

960960

961-
def _custom_op_with_schema(qualname, schema):
961+
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
962962
ns, name = qualname.split("::")
963963
schema_str = f"{name}{schema}"
964964
function_schema = FunctionSchema.parse(schema_str)
965965
validate_schema(function_schema)
966-
966+
tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
967967
lib = library.Library(ns, "FRAGMENT")
968-
lib.define(schema_str)
968+
lib.define(schema_str, tags=tags)
969969
ophandle = find_ophandle_or_throw(ns, function_schema.name)
970970
result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
971971
result._register_autograd_kernel_indirection()

torch/_inductor/graph.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
TensorBox,
4848
)
4949
from .lowering import (
50+
constrain_to_fx_strides,
5051
FALLBACK_ALLOW_LIST,
5152
fallback_handler,
5253
fallback_node_due_to_unsupported_type,
@@ -671,6 +672,23 @@ def call_function(self, target, args, kwargs):
671672
# passthrough lowerings from .pattern_matcher
672673
return target(*args, **kwargs)
673674

675+
def get_custom_op_layout_constraints(target, args, kwargs):
676+
# Custom operations that require preserving stride order
677+
# which run through implicit fallback must constrain their
678+
# arguments' fx strides
679+
layout_constraint = None
680+
if torch._C.Tag.needs_fixed_stride_order in target.tags:
681+
# We have to set the current args because call_function will immediately
682+
# evaluate this lowering after creating the fallback, without evaluating
683+
# the layout constraint
684+
args, kwargs = constrain_to_fx_strides(
685+
self.current_node, *args, **kwargs
686+
)
687+
# Also register the layout constraint so when the fallback
688+
# is used again, we can constrain the args to the same layout
689+
layout_constraint = constrain_to_fx_strides
690+
return layout_constraint, args, kwargs
691+
674692
if target not in lowerings:
675693
assert isinstance(
676694
target, torch._ops.OpOverload
@@ -679,6 +697,9 @@ def call_function(self, target, args, kwargs):
679697
if base_name in FALLBACK_ALLOW_LIST:
680698
make_fallback(target)
681699
elif config.implicit_fallbacks:
700+
layout_constraint, args, kwargs = get_custom_op_layout_constraints(
701+
target, args, kwargs
702+
)
682703
error = (
683704
MissingOperatorWithDecomp
684705
if get_decompositions([target])
@@ -688,7 +709,8 @@ def call_function(self, target, args, kwargs):
688709
"Creating implicit fallback for:\n%s",
689710
error.operator_str(target, args, kwargs),
690711
)
691-
make_fallback(target)
712+
make_fallback(target, layout_constraint)
713+
692714
elif get_decompositions([target]):
693715
# There isn't a good way to dynamically patch this in
694716
# since AOT Autograd already ran. The error message tells

0 commit comments

Comments
 (0)