Skip to content

Commit 3be1c5e

Browse files
Added better error messages for type validators (stack trace) (#7999)
* Added better error messages for type validators (stack trace) * changed error message format (moved newlines for better readability.) * Fixed linting issue. * Fixed unit test (it checks the first element of the new pair)
1 parent 9bd18f6 commit 3be1c5e

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

exir/tests/test_arg_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def forward(self, x):
6464
ops.edge.aten._log_softmax.default.name(),
6565
)
6666
self.assertDictEqual(
67-
validator.violating_ops[key],
67+
validator.violating_ops[key][0],
6868
{
6969
"self": torch.bfloat16,
7070
"__ret_0": torch.bfloat16,

exir/verification/arg_validator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
3737

3838
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
3939
super().__init__(graph_module)
40-
self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
41-
defaultdict(dict)
42-
)
40+
self.violating_ops: Dict[
41+
EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node]
42+
] = defaultdict(dict)
4343

4444
def run_node(self, n: torch.fx.Node) -> None:
4545
self.node = n
@@ -125,5 +125,5 @@ def call_function( # noqa: C901 # pyre-fixme[14]
125125

126126
valid = target._schema.dtype_constraint.validate(tensor_arg_types)
127127
if not valid:
128-
self.violating_ops[target] = tensor_arg_types
128+
self.violating_ops[target] = (tensor_arg_types, self.node)
129129
return super().call_function(target, args, kwargs) # pyre-fixme[6]

exir/verification/verifier.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,15 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
189189
return
190190

191191
if validator.violating_ops:
192+
error_msg = ""
193+
for op, node in validator.violating_ops.items():
194+
# error_msg += f"#####################################################\n"
195+
error_msg += f"\nOperator: {op} with args: {node[0]}\n"
196+
error_msg += f"stack trace: {node[1].stack_trace}\n"
197+
# error_msg += f"#####################################################\n"
192198
raise SpecViolationError(
193-
f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
194-
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding "
199+
f"These operators are taking Tensor inputs with mismatched dtypes:\n{error_msg}"
200+
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs."
195201
)
196202

197203

0 commit comments

Comments
 (0)