@@ -29,30 +29,40 @@ def get_input_act_qspec(self) -> QuantizationSpec | None:
    """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
    if self.input_activation is None:
        return None
-    assert self.input_activation.qscheme in [
+    # Validate that input_activation uses a supported qscheme
+    if self.input_activation.qscheme not in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
-    ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
+    ]:
+        raise ValueError(
+            f"Unsupported quantization_spec {self.input_activation} for input_activation."
+        )
    return self.input_activation

def get_output_act_qspec(self) -> QuantizationSpec | None:
    """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
    if self.output_activation is None:
        return None
-    assert self.output_activation.qscheme in [
+    # Validate that output_activation uses a supported qscheme
+    if self.output_activation.qscheme not in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
-    ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
+    ]:
+        raise ValueError(
+            f"Unsupported quantization_spec {self.output_activation} for output_activation."
+        )
    return self.output_activation

def get_weight_qspec(self) -> QuantizationSpec | None:
    """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
    if self.weight is None:
        return None
-    assert self.weight.qscheme in [
+    # Validate that weight uses a supported qscheme
+    if self.weight.qscheme not in [
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
-    ], f"Unsupported quantization_spec {self.weight} for weight"
+    ]:
+        raise ValueError(f"Unsupported quantization_spec {self.weight} for weight")
    return self.weight

def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
@@ -61,11 +71,11 @@ def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
    def _derive_qparams_fn(
        obs_or_fqs: list[ObserverOrFakeQuantize],
    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert (
-            len(obs_or_fqs) == 2
-        ), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
-            len(obs_or_fqs)
-        )
+        # Validate expected number of observers/fake-quantizes
+        if len(obs_or_fqs) != 2:
+            raise ValueError(
+                f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+            )
        act_obs_or_fq = obs_or_fqs[0]
        weight_obs_or_fq = obs_or_fqs[1]
        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@@ -94,9 +104,11 @@ def _derive_qparams_fn(

    if self.bias is None:
        return None
-    assert (
-        self.bias.dtype == torch.float
-    ), "Only float dtype for bias is supported for bias right now"
+    # Validate that bias dtype is floating-point
+    if self.bias.dtype != torch.float:
+        raise ValueError(
+            "Only float dtype is supported for bias right now"
+        )
    return self.bias

def get_fixed_qspec(
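
For context on why this change matters: `assert` statements are compiled away when Python runs with the `-O` flag, so the old checks could silently admit unsupported qschemes in optimized runs, whereas an explicit `ValueError` is always raised. A minimal standalone sketch of the difference (the helper names and the `SUPPORTED_QSCHEMES` list are hypothetical, not part of this diff):

```python
import torch

# Hypothetical list mirroring the activation qschemes accepted above.
SUPPORTED_QSCHEMES = [torch.per_tensor_affine, torch.per_tensor_symmetric]


def check_with_assert(qscheme: torch.qscheme) -> None:
    # Compiled away under `python -O`: invalid qschemes slip through silently.
    assert qscheme in SUPPORTED_QSCHEMES, f"Unsupported qscheme {qscheme}"


def check_with_raise(qscheme: torch.qscheme) -> None:
    # Enforced regardless of interpreter flags.
    if qscheme not in SUPPORTED_QSCHEMES:
        raise ValueError(f"Unsupported qscheme {qscheme}")


check_with_raise(torch.per_tensor_affine)   # passes
check_with_raise(torch.per_channel_affine)  # raises ValueError
```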