@@ -36,8 +36,8 @@ def batch_norm(
36
36
input : TRTTensor ,
37
37
weight : Optional [Union [torch .Tensor , np .ndarray ]],
38
38
bias : Optional [Union [torch .Tensor , np .ndarray ]],
39
- running_mean : Union [TRTTensor , Optional [ Union [ torch .Tensor , np .ndarray ] ]],
40
- running_var : Union [TRTTensor , Optional [ Union [ torch .Tensor , np .ndarray ] ]],
39
+ running_mean : Optional [ Union [TRTTensor , torch .Tensor , np .ndarray ]],
40
+ running_var : Optional [ Union [TRTTensor , torch .Tensor , np .ndarray ]],
41
41
training : bool ,
42
42
momentum : float ,
43
43
eps : float ,
@@ -55,13 +55,13 @@ def batch_norm(
55
55
if isinstance (running_mean , TRTTensor ) or isinstance (running_var , TRTTensor ):
56
56
# Default values if weight, bias, running_mean, running_var are None
57
57
if weight is None :
58
- weight = get_trt_tensor (ctx , 1.0 , f"{ name } _weight" , input . dtype )
58
+ weight = get_trt_tensor (ctx , 1.0 , f"{ name } _weight" )
59
59
if bias is None :
60
- bias = get_trt_tensor (ctx , 0.0 , f"{ name } _bias" , input . dtype )
60
+ bias = get_trt_tensor (ctx , 0.0 , f"{ name } _bias" )
61
61
if running_mean is None :
62
- running_mean = get_trt_tensor (ctx , 0.0 , f"{ name } _running_mean" , input . dtype )
62
+ running_mean = get_trt_tensor (ctx , 0.0 , f"{ name } _running_mean" )
63
63
if running_var is None :
64
- running_var = get_trt_tensor (ctx , 1.0 , f"{ name } _running_var" , input . dtype )
64
+ running_var = get_trt_tensor (ctx , 1.0 , f"{ name } _running_var" )
65
65
66
66
# eps_tensor for numerical stability
67
67
eps_tensor = get_trt_tensor (ctx , eps , f"{ name } _eps" )
0 commit comments