Remove memory-format workaround for Arm backend #3981
Conversation
```python
    dim_order = HWCM_Order
else:
    dim_order = tuple(range(len(data.shape)))
node.meta["tosa_dim_order"] = dim_order
```
Let me think, there are at least three ways of doing this.

(1) The way you are doing it, i.e. in preprocess(): insert a dim_order tuple, track it manually, and transform things as we convert to TOSA.
👍 Flexibility
👍 Mostly transparent to the user, except the compile_spec
👎 You are expecting the input to call_delegate() to be in NHWC format
👎 Doesn't use PyTorch infra
👎 node.meta["tosa_dim_order"] != node.meta["val"].dim_order(), which may cause confusion

(2) Global switch, i.e. gm = export(model.to(memory_format=torch.channels_last), torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)). This way gm should have everything in NHWC dim_order.
👍 Clean impl, uses PyTorch infra
👎 Anything that falls out of the delegate must also run in NHWC format
👎 Assumes input conversion to NHWC happens outside of the model. Might be OK if you are getting it in NHWC already.

(3) Runtime conversion: after export, we insert to_dim_order(NHWC) ops at the beginning/end (also on the weights/placeholders etc.) as we want, and consume these nodes as we see fit in the delegate.
👍 A mix of (1) and (2); lets you select where the input conversion to NHWC happens.
👍 Flexible in terms of which node runs in which format: if you leave the node in, it takes NCHW and converts it at runtime on the CPU; if you partition and delete the node, the input to the delegate and to the model (I think, I have to check; this is similar to deleting the first q node) will be NHWC.
👍 Still uses PyTorch infra, i.e. no need to manually maintain tosa_dim_order, except for cases like DW conv (but I see that as a special constraint from the backend, so...)
👎 Slightly harder to implement, particularly if you are sprinkling multiple nodes in the graph. But if you run this pass before the partitioner, it might not be too difficult.
👎 Potentially involves running a pass outside, i.e. modifying the graph before to_backend(). As long as this is user visible, it might be OK.
👎 Exposes some footguns when multiple delegates are involved, but that shouldn't be the case here.
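For illustration, a minimal sketch of option (2) on a made-up toy module (the module, shapes, and the dim_order check are assumptions for demonstration, not code from this PR):

```python
import torch

class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

# Put both the weights and the example input in channels_last before export.
model = TinyConv().to(memory_format=torch.channels_last)
example = torch.randn(1, 2, 8, 8).to(memory_format=torch.channels_last)
gm = torch.export.export(model, (example,)).module()

# Intermediate values should now carry NHWC dim_order, i.e. for 4-D
# activations node.meta["val"].dim_order() should be (0, 2, 3, 1).
```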
Thank you Digant for your comments, much appreciated!

(1) This would assume that the input is in NHWC format if permute_memory_to_nhwc==True, yes. The next step after this PR could be to add a pass, or extend this pass, to insert to_dim_order ops around all our delegates such that the graph becomes: input (default format) -> ... -> to_dim_order(nhwc) -> delegate -> to_dim_order(default format) -> ... -> output (default format). This would require that we modify the graph outside (just before and after) our delegate after partitioning. What's the proper way of doing this?

I agree with the last two bullets. Those are the drawbacks of the current solution.
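To make the question concrete, here is a hedged sketch of that kind of graph surgery on an FX graph. `to_dim_order` is a stand-in for whatever conversion op would actually be used, and the delegate-target check is a simplification; treat all names as assumptions:

```python
import torch

def wrap_delegate_with_dim_order(gm: torch.fx.GraphModule, to_dim_order) -> None:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and "call_delegate" in str(node.target):
            # Convert tensor inputs to NHWC just before the delegate. Real code
            # would skip non-tensor args such as the lowered-module reference.
            with gm.graph.inserting_before(node):
                node.args = tuple(
                    gm.graph.call_function(
                        to_dim_order, (arg,), {"dim_order": (0, 2, 3, 1)}
                    )
                    if isinstance(arg, torch.fx.Node)
                    else arg
                    for arg in node.args
                )
            # Convert the delegate's output back to the default order.
            with gm.graph.inserting_after(node):
                back = gm.graph.call_function(
                    to_dim_order, (node,), {"dim_order": (0, 1, 2, 3)}
                )
            node.replace_all_uses_with(back)
            back.args = (node,)  # replace_all_uses_with also rewrote back's own input
    gm.graph.lint()
    gm.recompile()
```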
(2) model.to(memory_format=torch.channels_last) and node.meta["val"].to(memory_format=torch.channels_last) give dim_order=(0, 1, 2, 3) for certain shapes such as (3, 1, 3, 3). Otherwise, I think node.meta["val"].to(memory_format=torch.channels_last) could have been used in (1) instead of annotating node.meta["tosa_dim_order"].
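A quick demonstration of that corner case (my reading of the behavior: with a singleton channel dim, a contiguous tensor already passes the channels_last contiguity check, so .to() leaves the strides unchanged; worth re-verifying on your PyTorch version):

```python
import torch

t = torch.randn(3, 1, 3, 3).to(memory_format=torch.channels_last)
print(t.dim_order())  # (0, 1, 2, 3), not the (0, 2, 3, 1) one might expect
```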
(3) I don't think I fully understand this solution. Do we insert to_dim_order ops at the inputs, outputs and placeholders of the full graph, or just of our delegate sub-graph? Will the dim_order be propagated forward when inserting to_dim_order ops?
I will take a look tomorrow.
```python
elif node.op == "placeholder":
    # node is an input, weight or bias node
    consumer_node = list(node.users)[0]
    if self.is_weight_node_for_dw_conv(consumer_node):
```
nit: use consistent naming for depthwise convolution helper functions?
Updated in latest rebase.
```python
if self.is_weight_node_for_dw_conv(node):
    dim_order = HWCM_Order
else:
    dim_order = tuple(range(node_data.dim()))
```
I assume this is mainly for the dim_order-agnostic ops?
I consider it to be for all ops, as we assume that the input is always transposed. It corresponds to this block: https://github.com/pytorch/executorch/blob/main/backends/arm/arm_backend.py#L261C1-L261C70.
```diff
@@ -120,7 +120,7 @@ def _test_linear_tosa_MI_pipeline(
     ArmTester(
         module,
         example_inputs=test_data,
-        compile_spec=common.get_tosa_compile_spec(),
+        compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
```
I guess rank==4 ==> NHWC breaks this?
Yes, (1, 1, 1, C) works though. As a step towards solution (2), we plan on removing this flag in the near future.
```python
    data=None,
    placeholderFilename=inputs[0].name + ".npy",
)
tosa_graph.addInputTensor(tensor)
```
IIUC this means we will expect an NHWC input tensor. Do you want to add a runtime assert for this? Otherwise it might just consume wrong-format data, which would be tricky to debug.
The functionality is the same as before, but it sounds like a good thing to do. We can do it in an upcoming PR, or when we implement (2), if that's okay?
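A minimal sketch of the kind of check being suggested, at the Python level (the real assert would presumably live in the runtime; the helper name is hypothetical):

```python
import torch

def assert_nhwc(t: torch.Tensor) -> None:
    # Fail loudly instead of silently consuming wrong-format data.
    if t.dim() == 4 and tuple(t.dim_order()) != (0, 2, 3, 1):
        raise RuntimeError(f"Expected NHWC input, got dim_order={tuple(t.dim_order())}")
```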
```diff
@@ -16,7 +16,7 @@
     build_rescale_from_int32,
     build_rescale_to_int32,
 )
-from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs
+from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
```
Are you planning to add more ops? For instance, the Permute op's argument needs to be updated based on the format of the input it gets when you update the previous node's output format.
Yes, there are a few open PRs with new ops. The ambition is that we can just update shapes, such as permute's argument, with tosa_shape(shape, dim_order).
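For context, a plausible sketch of what a tosa_shape(shape, dim_order) helper does; the real implementation lives in executorch.backends.arm.tosa_utils and may differ:

```python
def tosa_shape(shape, dim_order):
    # Reorder a shape according to the given dim order,
    # e.g. NCHW -> NHWC for activations.
    return tuple(shape[dim] for dim in dim_order)

assert tosa_shape((1, 8, 16, 16), (0, 2, 3, 1)) == (1, 16, 16, 8)
```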
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
def call(self, graph_module: torch.fx.GraphModule):
    NHWC_Order = (0, 2, 3, 1)
    HWCM_Order = (2, 3, 0, 1)
```
I don't think I follow this dim order. Can you add some comments? Did you mean HWNC?
It's a special case for depthwise convolution, where the weights have shape (H, W, C, M) and hence a different dim order (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
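To illustrate (the shapes and the final reshape are my reading of the TOSA spec, not code from this PR): a PyTorch depthwise Conv2d weight has shape (C * M, 1, kH, kW), and the (2, 3, 0, 1) order moves the kernel dims to the front:

```python
import torch

w = torch.randn(6, 1, 3, 3)         # depthwise weight: C=3 channels, M=2 multiplier
print(w.permute(2, 3, 0, 1).shape)  # torch.Size([3, 3, 6, 1])
# A subsequent reshape to (kH, kW, C, M) = (3, 3, 3, 2) would match
# TOSA's depthwise_conv2d weight layout (H, W, C, M).
```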
So it's not clear to me if this is really making use of the dim order utils? Is this in preparation for that?

Yes, this is the intermediate step. See the conversation on Slack: https://pytorch.slack.com/archives/C01FV3A914N/p1719838104994989?thread_ts=1718355881.095489&cid=C01FV3A914N
The dim-order of each node is annotated in a pass. Some refactoring of arm_backend.py. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I14691b51b99acb9e8605100fd25731ab45c55a9d
Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ica6addb95d6b925beef4696780334268821af608
Change-Id: Id0877faeddec21acdb918fc773bee410dbe6dbb5
@digantdesai merged this pull request in d3c92de.
Remove the temporary fix for memory format introduced in #2371. The dim order of each node is annotated in a pass. Also some refactoring of arm_backend.py.
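For reference, a consolidated sketch of the annotation pass reconstructed from the diff fragments quoted above. The class name is an assumption, the helper body is elided, and the real pass also branches on the permute_memory_to_nhwc compile spec and handles placeholders via their consumer node:

```python
import torch
from executorch.exir.pass_base import ExportPass
from torch.fx.passes.infra.pass_base import PassResult

NHWC_Order = (0, 2, 3, 1)  # rank-4 activations when permuting memory to NHWC
HWCM_Order = (2, 3, 0, 1)  # depthwise-conv weights: TOSA expects (H, W, C, M)

class AnnotateTosaDimOrderPass(ExportPass):  # name is an assumption
    def is_weight_node_for_dw_conv(self, node: torch.fx.Node) -> bool:
        ...  # elided: does this node feed a depthwise conv as its weight?

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if "val" not in node.meta:  # e.g. the lowered-module attribute
                continue
            data = node.meta["val"]
            if self.is_weight_node_for_dw_conv(node):
                dim_order = HWCM_Order
            elif data.dim() == 4:
                dim_order = NHWC_Order
            else:
                dim_order = tuple(range(data.dim()))
            node.meta["tosa_dim_order"] = dim_order
        return PassResult(graph_module, True)
```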