Skip to content

Commit f346d43

Browse files
hsharma35facebook-github-bot
authored andcommitted
Enable IR checks (#7408)
Summary: As titled. Reviewed By: zonglinpeng Differential Revision: D59638386
1 parent f341da8 commit f346d43

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

backends/cadence/aot/compiler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,15 @@ def export_to_edge(
201201
edge_prog_manager = to_edge(
202202
expo_program,
203203
compile_config=EdgeCompileConfig(
204-
_check_ir_validity=False, _skip_dim_order=True
204+
_skip_dim_order=True,
205+
# Allow specific non-core aten ops in the IR.
206+
_core_aten_ops_exception_list=[
207+
torch.ops.aten.linear.default,
208+
torch.ops.aten.native_batch_norm.default,
209+
torch.ops.aten.linalg_vector_norm.default,
210+
torch.ops.aten.unfold.default,
211+
torch.ops.aten.angle.default,
212+
],
205213
),
206214
)
207215

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def call_operator(self, op, args, kwargs, meta):
579579
arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :]
580580
)
581581
left_padding_node = super().call_operator(
582-
torch.ops.aten.full.default,
582+
exir_ops.edge.aten.full.default,
583583
(
584584
left_padding_shape,
585585
value,
@@ -596,7 +596,7 @@ def call_operator(self, op, args, kwargs, meta):
596596
arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :]
597597
)
598598
right_padding_node = super().call_operator(
599-
torch.ops.aten.full.default,
599+
exir_ops.edge.aten.full.default,
600600
(
601601
right_padding_shape,
602602
value,
@@ -726,7 +726,7 @@ def call_operator(self, op, args, kwargs, meta):
726726

727727
flipped_weight = (
728728
super().call_operator(
729-
torch.ops.aten.flip.default,
729+
exir_ops.edge.aten.flip.default,
730730
(
731731
transposed_weight,
732732
[-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],

0 commit comments

Comments
 (0)