Skip to content

Commit 323ff94

Browse files
authored
Arm backend: Introduce TOSA aware context (#11475)
Introduce a 'TosaLoweringContext' to track the TOSA specification being used in the backend to be able to make decisions in the pipeline. Signed-off-by: Per Åstrand <[email protected]>
1 parent 4df40e1 commit 323ff94

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262
UnsqueezeScalarPlaceholdersPass,
6363
)
6464

65-
from executorch.backends.arm.tosa_specification import TosaSpecification
65+
from executorch.backends.arm.tosa_specification import (
66+
TosaLoweringContext,
67+
TosaSpecification,
68+
)
6669
from executorch.backends.transforms.decompose_sdpa import (
6770
DecomposeScaledDotProductAttention,
6871
)
@@ -80,7 +83,8 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
8083
super().__init__()
8184

8285
def _transform(self, graph_module: GraphModule):
83-
return self(graph_module).graph_module
86+
with TosaLoweringContext(self.tosa_spec):
87+
return self(graph_module).graph_module
8488

8589
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
8690
self.add_pass(FuseQuantizedActivationPass())

backends/arm/tosa_specification.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# JIT compiler flows.
1212
#
1313

14+
import contextvars
1415
import re
1516
from typing import List
1617

@@ -214,3 +215,34 @@ def support_integer(self):
214215

215216
def support_float(self):
216217
return "FP" in self.profiles
218+
219+
220+
class TosaLoweringContext:
221+
"""
222+
A context manager to handle the TOSA specific aspects of the lowering process.
223+
For now it only handles the TOSA specification context, but it can be extended
224+
to include other policies or configurations.
225+
"""
226+
227+
# Define a context variable for the spec
228+
tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec")
229+
230+
def __init__(self, spec: TosaSpecification):
231+
self.spec = spec
232+
233+
def __enter__(self):
234+
# Set the spec in the context variable and store the token for later reset
235+
self.token = TosaLoweringContext.tosa_spec_var.set(self.spec)
236+
return self
237+
238+
def __exit__(self, exc_type, exc_value, traceback):
239+
# Reset the context variable to its previous state
240+
TosaLoweringContext.tosa_spec_var.reset(self.token)
241+
242+
243+
# A helper function to retrieve the current spec anywhere in your code
244+
def get_context_spec() -> TosaSpecification:
245+
try:
246+
return TosaLoweringContext.tosa_spec_var.get()
247+
except LookupError:
248+
raise RuntimeError("Function must be executed within a TosaLoweringContext")

0 commit comments

Comments
 (0)