Group norm bug fix #3014
Changes from all commits
@@ -149,6 +149,8 @@ def native_group_norm(
    eps: float,
    return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # TODO: Ask TRT team about the usage of INormalization Layer with num_groups and update the implementation
    # with INormalization Layer
    assert (
        len(input.shape) >= 3
    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
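For reference, the decomposition built in the next hunk implements the standard group-norm formula (the in-code comments call out the same steps):

$$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta $$

with $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ taken per sample over each group of C // group channels together with all spatial positions.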
@@ -187,28 +189,105 @@ def native_group_norm(
        shape,
    )

    if weight is None:
        weight = to_numpy(1.0)

    if bias is None:
        bias = to_numpy(0.0)

    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
    if tuple(reshaped_input.shape) != tuple(weight.shape):
        weight = impl.slice.expand(
            ctx,
            target,
            source_ir,
            f"{name}_expand_weight",
            weight,
            reshaped_input.shape,
        )
    if tuple(reshaped_input.shape) != tuple(bias.shape):
        bias = impl.slice.expand(
            ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
        )
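    # Shape (1, C, 1, ..., 1): lets per-channel gamma/beta broadcast over batch and spatial dims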
    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

    dims = list(range(1, len(input.shape)))
    axes = get_axes_for_reduce_op(dims)
    group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
    group_norm.epsilon = eps
    group_norm.compute_precision = input.dtype
    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
    output = group_norm.get_output(0)

    # E[X]
    mean_trt = impl.reduce.mean(
        ctx,
        target,
        source_ir,
        f"{name}_mean",
        reshaped_input,
        dims,
        True,
    )

    mean_trt = impl.slice.expand(
        ctx,
        target,
        source_ir,
        f"{name}_expand_mean_trt",
        mean_trt,
        reshaped_input.shape,
    )

    # X - E[X]
    sub_trt = impl.elementwise.sub(
        ctx,
        target,
        source_ir,
        f"{name}_sub",
        reshaped_input,
        mean_trt,
    )

    # variance
    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
    pow_var = impl.elementwise.pow(
        ctx,
        target,
        source_ir,
        f"{name}_pow",
        sub_trt,
        pow_trt,
    )

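    # Var[X] as the biased variance, i.e. mean((X - E[X])^2)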
    var_trt = impl.reduce.mean(
        ctx,
        target,
        source_ir,
        f"{name}_mean_var",
        pow_var,
        dims,
        True,
    )

    var_trt = impl.slice.expand(
        ctx,
        target,
        source_ir,
        f"{name}_expand_var_trt",
        var_trt,
        reshaped_input.shape,
    )

    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)

Review thread:
  Comment: Why is this converted to np.float32?
  Reply: It is 1e-6 in most cases.
  Reply: float16 cannot handle values that close to 0.
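A quick NumPy check (an illustration, not part of this PR) shows why eps is pinned to float32: 1e-6 sits below float16's smallest normal value (about 6.1e-5), so it is stored as a subnormal with a noticeable relative error:

    import numpy as np

    eps = 1e-6
    # The nearest float16 is a subnormal multiple of 2**-24 (~5.96e-8)
    print(np.float16(eps))   # ~1.013e-06
    print(np.float32(eps))   # 1e-06
    print(abs(float(np.float16(eps)) - eps) / eps)  # ~0.013, i.e. ~1.3% relative error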

    add_trt = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_add",
        var_trt,
        eps_trt,
    )

    sqrt_trt = impl.unary.sqrt(
        ctx,
        target,
        source_ir,
        f"{name}_sqrt",
        add_trt,
    )

    # y = (X - E[X]) / sqrt((var + eps))
    output = impl.elementwise.div(
        ctx,
        target,
        source_ir,
        f"{name}_div",
        sub_trt,
        sqrt_trt,
    )

Review thread:
  Comment: Wanted to clarify whether this div would require any mode, e.g. trunc? Are the data types always compatible with the output types?
  Reply: I don't think so. The previous implementation from Evan did not include any mode.
  Reply: just FYI: we have …
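For context (an illustration, not from this PR), a rounding mode only matters when emulating integer-style division; the plain floating-point division used here takes none:

    import torch

    a = torch.tensor([7.0, -7.0])
    b = torch.tensor([2.0, 2.0])
    print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])
    print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.]), rounds toward zero
    print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.]), rounds toward -inf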

    shape = list(output.shape)
    for i, s in enumerate(shape):

@@ -222,6 +301,40 @@ def native_group_norm(
    reshaped_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape_output", output, shape
    )
    reshaped_gamma = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape_gamma",
        weight,
        weight_bias_shape,
    )

    reshaped_output = impl.elementwise.mul(
        ctx,
        target,
        source_ir,
        f"{name}_mul_gamma",
        reshaped_output,
        reshaped_gamma,
    )

    reshaped_bias = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape_beta",
        bias,
        weight_bias_shape,
    )
    reshaped_output = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_add_beta",
        reshaped_output,
        reshaped_bias,
    )

    if return_mean_rstd:
        # return fake mean and rstd for now
        return reshaped_output, None, None
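As a sanity check, the decomposition the converter builds can be mirrored in plain PyTorch and compared against torch.nn.functional.group_norm (a standalone sketch, not part of the PR; shapes and names are illustrative):

    import torch
    import torch.nn.functional as F

    N, C, H, W, group, eps = 2, 6, 4, 4, 3, 1e-6
    x = torch.randn(N, C, H, W)
    weight = torch.randn(C)  # gamma
    bias = torch.randn(C)    # beta

    # Reshape to (N * group, C // group, H, W) and normalize over every
    # dimension except the first, exactly as dims = range(1, ndim) above.
    xg = x.reshape(N * group, C // group, H, W)
    mean = xg.mean(dim=(1, 2, 3), keepdim=True)
    var = ((xg - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
    y = (xg - mean) / torch.sqrt(var + eps)

    # Reshape back, then apply gamma/beta with the (1, C, 1, 1) broadcast shape.
    y = y.reshape(N, C, H, W)
    y = y * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

    print(torch.allclose(y, F.group_norm(x, group, weight, bias, eps), atol=1e-5))  # True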