-
Notifications
You must be signed in to change notification settings - Fork 36
Length check before applying TensorProduct #111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Length check before applying TensorProduct #111
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for opening an issue to sort this out. I think a slightly more general solution is in order though, as this problem is definitely floating around all over the place, and our errors are not the most easy to decipher at the minute.
In particular, the validate_dims
and validate_inplace_dims
functions in util.jl should be extended to also accept the kernel as an argument, and a check added to ensure that the inputs and the kernel have consistent input dimensionality -- this could be done cleanly by adding methods of dim
for a various kernels.
Not all kernels have a fixed / known input dimensionality, so a sensible fallback for dim(::Kernel)
would be a singleton type called ArbitraryDimensionality
with no fields that you can error check on. I'm thinking something like
struct ArbitraryDimensionality end
dim(::Kernel) = ArbitraryDimensionality()
dim(k::TensorProduct) = length(kernel.kernels)
# Example validate_dims implementation
function validate_dims(k::Kernel, x::AbstractVector, y::AbstractVector)
if dim(x) != dim(y)
throw(DimensionMismatch(
"Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
))
end
validate_kernel_dims(k, x)
end
function validate_kernel_dims(k::Kernel, x::AbstractVector)
if dim(k) !== ArbitraryDimensionality() && dim(k) != dim(x)
throw(DimensionMismatch(
"Dimensionality of kernel ($(dim(k))) not compatible ",
"with that of input ($(dim(y)))",
))
end
end
If you would rather open this out into a separate issue, that's also fine. But I think my comment regarding the type of error should definitely be addressed before merging as is.
An alternative would be to define isdimcompatible(::Kernel, x) = true
isdimcompatible(kernel::TensorProductKernel, x) = length(kernel.kernels) == dim(x)
function validate_kernel_dims(k::Kernel, x::AbstractVector)
if isdimcompatible(k, x)
throw(DimensionMismatch(
"Dimensionality of kernel not compatible ",
"with that of input",
))
end
end and only define |
Think I would prefer the name |
Actually do you think this is something we could put in common to use with specific transforms. Like if one use |
I guess you could just define (not sure about the exact fields an types) function dims_are_compatible(kernel::TransformedKernel{<:Kernel,<:LinearTransform}, x)
return size(kernel.transform.P, 2) == dim(x)
end Did you have that in mind? |
Yeah but it gets trickier for more complex structures. I think it's better to leave the error happen when the transformation is applied, and return a meaningful message. |
@willtebbutt I have addressed this suggestion. I will create a separate PR for a more general solution? |
LGTM. Will merge once tests pass. |
@sharanry yes please 🙂 |
As discussed with @willtebbutt earlier,
TensorProduct
doesn't check the dimension of the input before applying. This can lead to bugs which are hard to trace.Currently:
It picks up the first
length(k)
dimensions of the input, ignoring the rest.