@@ -79,22 +79,16 @@ def extract_tensor_meta(meta):
79
79
# Class to capture arguments and turn into tensor references for TOSA OPs
80
80
class TosaArg :
81
81
def __process_node (self , argument : torch .fx .Node ):
82
- self .name = argument .name
82
+ self .name : str = argument .name
83
83
self .dtype , self .shape , self .dim_order = extract_tensor_meta (argument .meta )
84
84
85
85
def __process_list (self , argument ):
86
- self .special = list (argument )
86
+ self .special : list = list (argument )
87
87
88
88
def __process_number (self , argument : float | int ):
89
- self .number = argument
89
+ self .number : float | int = argument
90
90
91
91
def __init__ (self , argument : Any ) -> None :
92
- self .name = None # type: ignore[assignment]
93
- self .dtype = None
94
- self .shape = None
95
- self .dim_order = None
96
- self .special = None
97
-
98
92
if argument is None :
99
93
return
100
94
@@ -114,3 +108,20 @@ def __init__(self, argument: Any) -> None:
114
108
raise RuntimeError (
115
109
f"Unhandled node input argument: { argument } , of type { type (argument )} "
116
110
)
111
+
112
+ def __repr__ (self ):
113
+ attrs = []
114
+ if hasattr (self , "name" ):
115
+ if self .name is not None :
116
+ attrs .append (f"name={ self .name !r} " )
117
+ if self .dtype is not None :
118
+ attrs .append (f"dtype={ ts .DTypeNames [self .dtype ]} " )
119
+ if self .shape is not None :
120
+ attrs .append (f"shape={ self .shape !r} " )
121
+ if self .dim_order is not None :
122
+ attrs .append (f"dim_order={ self .dim_order !r} " )
123
+ if hasattr (self , "special" ) and self .special is not None :
124
+ attrs .append (f"special={ self .special !r} " )
125
+ if hasattr (self , "number" ) and self .number is not None :
126
+ attrs .append (f"number={ self .number !r} " )
127
+ return f"{ self .__class__ .__name__ } ({ ', ' .join (attrs )} )"
0 commit comments