Skip to content

Length check before applying TensorProduct #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2020
Merged

Length check before applying TensorProduct #111

merged 2 commits into from
May 5, 2020

Conversation

sharanry
Copy link
Contributor

@sharanry sharanry commented May 4, 2020

As discussed with @willtebbutt earlier, TensorProduct doesn't check the dimension of the input before applying. This can lead to bugs which are hard to trace.

Currently:

julia> k = TensorProduct(ConstantKernel(c=2.0), ConstantKernel(c=2.0))
Tensor product of 2 kernels:
	- Constant Kernel (c = 2.0)
	- Constant Kernel (c = 2.0)

julia> k(rand(3), rand(3))
4.0

It picks up the first length(k) dimensions of the input, ignoring the rest.

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for opening an issue to sort this out. I think a slightly more general solution is in order though, as this problem is definitely floating around all over the place, and our errors are not the most easy to decipher at the minute.

In particular, the validate_dims and validate_inplace_dims functions in util.jl should be extended to also accept the kernel as an argument, and a check added to ensure that the inputs and the kernel have consistent input dimensionality -- this could be done cleanly by adding methods of dim for a various kernels.

Not all kernels have a fixed / known input dimensionality, so a sensible fallback for dim(::Kernel) would be a singleton type called ArbitraryDimensionality with no fields that you can error check on. I'm thinking something like

struct ArbitraryDimensionality end

dim(::Kernel) = ArbitraryDimensionality()
dim(k::TensorProduct) = length(kernel.kernels)

# Example validate_dims implementation

function validate_dims(k::Kernel, x::AbstractVector, y::AbstractVector)
    if dim(x) != dim(y)
        throw(DimensionMismatch(
            "Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
        ))
    end
    validate_kernel_dims(k, x)
end

function validate_kernel_dims(k::Kernel, x::AbstractVector)
    if dim(k) !== ArbitraryDimensionality() && dim(k) != dim(x)
        throw(DimensionMismatch(
            "Dimensionality of kernel ($(dim(k))) not compatible ",
            "with that of input ($(dim(y)))",
        ))
    end
end

If you would rather open this out into a separate issue, that's also fine. But I think my comment regarding the type of error should definitely be addressed before merging as is.

@devmotion
Copy link
Member

devmotion commented May 4, 2020

Not all kernels have a fixed / known input dimensionality, so a sensible fallback for dim(::Kernel) would be a singleton type called ArbitraryDimensionality with no fields that you can error check on. I'm thinking something like

An alternative would be to define

isdimcompatible(::Kernel, x) = true
isdimcompatible(kernel::TensorProductKernel, x) = length(kernel.kernels) == dim(x)

function validate_kernel_dims(k::Kernel, x::AbstractVector)
    if isdimcompatible(k, x)
        throw(DimensionMismatch(
            "Dimensionality of kernel not compatible ",
            "with that of input",
        ))
    end
end

and only define isdimcompatible(kernel, x) if needed.

@willtebbutt
Copy link
Member

Think I would prefer the name dims_are_compatible(k, x), but that's a good point @devmotion .

@theogf
Copy link
Member

theogf commented May 4, 2020

Actually do you think this is something we could put in common to use with specific transforms. Like if one use LinearTransform or ARDTransform the dimension of the kernel is no more arbitrary.

@devmotion
Copy link
Member

I guess you could just define (not sure about the exact fields an types)

function dims_are_compatible(kernel::TransformedKernel{<:Kernel,<:LinearTransform}, x)
    return size(kernel.transform.P, 2) == dim(x)
end

Did you have that in mind?

@theogf
Copy link
Member

theogf commented May 4, 2020

Yeah but it gets trickier for more complex structures. I think it's better to leave the error happen when the transformation is applied, and return a meaningful message.

@sharanry
Copy link
Contributor Author

sharanry commented May 5, 2020

If you would rather open this out into a separate issue, that's also fine. But I think my comment regarding the type of error should definitely be addressed before merging as is.

@willtebbutt I have addressed this suggestion. I will create a separate PR for a more general solution?

@willtebbutt
Copy link
Member

LGTM. Will merge once tests pass.

@willtebbutt willtebbutt merged commit 2e92ecd into JuliaGaussianProcesses:master May 5, 2020
@willtebbutt
Copy link
Member

I will create a separate PR for a more general solution?

@sharanry yes please 🙂

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants