-
Notifications
You must be signed in to change notification settings - Fork 91
Generic flatten (2d and 3d) #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
LGTM |
@@ -18,13 +18,20 @@ module nf_flatten_layer | |||
integer, allocatable :: input_shape(:) | |||
integer :: output_size | |||
|
|||
real, allocatable :: gradient(:,:,:) | |||
real, allocatable :: gradient_2d(:,:) | |||
real, allocatable :: gradient_3d(:,:,:) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I thought about that but decided not to make the code even less SOLID
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But here we have a choice between SOLID and less boilerplate, I think I agree that the second one is better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and, most importantly for me, this approach allows for a unified API (only one flatten()
for the user).
src/nf/nf_flatten_layer.f90
Outdated
|
||
procedure :: forward_2d | ||
procedure :: forward_3d | ||
generic :: forward => forward_2d, forward_3d |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps, just make it one method with assumed-rank input?
pure module subroutine forward(self, input)
class(flatten_layer), intent(in out) :: self
real, intent(in) :: input(..)
select rank(input)
rank(2)
self % output = pack(input, .true.)
rank(3)
self % output = pack(input, .true.)
rank default
error stop "Unsupported rank of input"
end select
end subroutine forward
It will reduce boilerplate a little
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! If it works, let's do it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Thank you! I'll rebase and test it out! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything works. Great job!
I attempted to make a generic
flatten
so that the user doesn't need to doflatten2d
. It seems like it will work.In support of
Linear2d
(#197)