-
Notifications
You must be signed in to change notification settings - Fork 36
Input check relaxation and change of `pairwise #147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
I guess it would be good to add some tests? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking great. Could you please add something to the docs explaining this interface for people who wish to implement new types?
src/utils.jl
Outdated
@@ -94,9 +91,24 @@ For a transform return its parameters, for a `ChainTransform` return a vector of | |||
""" | |||
#params | |||
|
|||
dim(x) = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add another comment here pointing out that this is a passes-by-default choice.
@willtebbutt I added some docs and changed slightly the |
Co-authored-by: David Widmann <[email protected]>
docs/src/create_kernel.md
Outdated
- Define the trainable parameters of your kernel with `KernelFunctions.trainable(k)` which should return a `Tuple` of your parameters. | ||
This parameters will be then passed to `Flux.params` function | ||
- `KernelFunctions.iskroncompatible(k)`, if your kernel factorizes in the dimensions. You can declare your kernel as `iskroncompatible(k) = true` | ||
- `KernelFunctions.dim`: by default the dimension of the inputs will only be checked for vectors of `AbstractVector{<:Real}`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be good to generalise this to say that you're allowed to add methods to validate_inputs
if you've got special checks that you want to run for your input types, with dim
being a nice special case.
Co-authored-by: willtebbutt <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with this now, subject to @devmotion being happy and a patch version bump.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This PR aims at :
validate_dims
is replaced byvalidate_inputs
. It will only check dimensions forColVecs
andRowVecs
.broadcast(k, x, y')
bybroadcast(k, x, permutedims(y))
in ourpairwise
for vectors. It appears to solve some Zygote issue (for me) and make it more coherent with the rest of the package