Infer logprob of Ifelse graphs #6529

Merged 1 commit on Mar 17, 2023

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Feb 17, 2023

import pymc as pm
import numpy as np

from pytensor.ifelse import ifelse

def ifelse_dist(cond, _):
    x = pm.Normal.dist(mu=[-1, 1], shape=(2,))
    y = pm.HalfNormal.dist(sigma=[0.1, 1, 10, 100], shape=(4,))
    return ifelse(cond, x, y)

with pm.Model() as m:
    cond = pm.Bernoulli("cond", p=0.5)
    y = pm.CustomDist("y", 1-cond, dist=ifelse_dist)

logp = m.compile_logp([y], sum=False)
print(logp({"cond": 0, "y": np.ones(2)}))  # [array([-2.91893853, -0.91893853])]
print(logp({"cond": 1, "y": np.ones(4)}))  # [array([-47.92320626,  -0.72579135,  -2.53337645,  -4.83101154])]

This kind of graph can't be handled by switch mixtures, because those require the outputs to have the same shape.

This could one day be useful for variable size samplers... for now it's just here for completeness and libraries building on top of PyMC.
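The two printed results above can be checked by hand. The following is a minimal NumPy sketch, not the PyMC implementation: `normal_logp`, `halfnormal_logp`, and `ifelse_logp` are hypothetical helpers written for illustration, and a plain Python `if` stands in for the lazy `ifelse`.

```python
import numpy as np


def normal_logp(value, mu=0.0, sigma=1.0):
    """Elementwise log-density of Normal(mu, sigma)."""
    value = np.asarray(value, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    z = (value - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def halfnormal_logp(value, sigma):
    """Elementwise log-density of HalfNormal(sigma); -inf for negative values."""
    value = np.asarray(value, dtype=float)
    logp = np.log(2.0) + normal_logp(value, 0.0, sigma)
    return np.where(value >= 0, logp, -np.inf)


def ifelse_logp(branch_cond, value):
    """Hypothetical stand-in for the IfElse logprob: evaluate only the selected branch."""
    if branch_cond:
        return normal_logp(value, mu=[-1.0, 1.0])
    return halfnormal_logp(value, sigma=[0.1, 1.0, 10.0, 100.0])


# The model passes 1 - cond to ifelse, so cond=0 selects the Normal branch.
logp_cond0 = ifelse_logp(1 - 0, np.ones(2))
logp_cond1 = ifelse_logp(1 - 1, np.ones(4))
print(logp_cond0)  # two values, matching the first printed result
print(logp_cond1)  # four values, matching the second printed result
```

Because only one branch is evaluated, the two branches are free to have different shapes, which is what a switch-based mixture cannot express.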

@codecov

codecov bot commented Feb 17, 2023

Codecov Report

Merging #6529 (d89ce91) into main (c709abf) will decrease coverage by 4.82%.
The diff coverage is 34.61%.

@@            Coverage Diff             @@
##             main    #6529      +/-   ##
==========================================
- Coverage   92.02%   87.21%   -4.82%     
==========================================
  Files          92       92              
  Lines       15563    15587      +24     
==========================================
- Hits        14322    13594     -728     
- Misses       1241     1993     +752     

Impacted Files            Coverage Δ
pymc/logprob/mixture.py   31.03% <34.61%> (-67.29%) ⬇️

... and 14 files with indirect coverage changes

Member

@lucianopaz lucianopaz left a comment

Nice approach @ricardoV94. I left a pair of comments. One thing that I found confusing was the first two commits. Are those really central to this PR, or could they be a PR in themselves?

rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}

logps_then = [
    logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
Member

What’s the behaviour of logprob when the value variable doesn’t have the expected shape? Will it raise a ValueError? I think that would be bad in this case. Imagine if you have a step method on the condition variable. The stepper might choose to move the condition to an infeasible value and that would kill the sampling process. I would like the condition that doesn’t match shapes to simply return -inf logprob. That way the stepper would discard the proposal and stay in reasonable regions.
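One way to picture the suggestion in this comment is a guard that maps a shape mismatch to -inf instead of raising. This is a hypothetical NumPy sketch, not something the PR implements; `branch_logp_or_neginf` and `std_normal_logp` are names invented for illustration, and the expected shape is assumed to be known up front.

```python
import numpy as np


def std_normal_logp(value):
    """Elementwise log-density of a standard normal."""
    value = np.asarray(value, dtype=float)
    return -0.5 * value**2 - 0.5 * np.log(2 * np.pi)


def branch_logp_or_neginf(value, expected_shape, logp_fn):
    """Hypothetical guard: joint logp of a branch, or -inf if the value has the wrong shape."""
    value = np.asarray(value, dtype=float)
    if value.shape != tuple(expected_shape):
        # A proposal that lands here would be rejected rather than crash the sampler.
        return -np.inf
    return float(np.sum(logp_fn(value)))


accepted = branch_logp_or_neginf(np.ones(2), (2,), std_normal_logp)
rejected = branch_logp_or_neginf(np.ones(4), (2,), std_normal_logp)
print(accepted, rejected)
```

With a guard like this, a stepper that proposes a condition whose branch does not match the current value would see a logp of -inf and stay where it is.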

Member Author

@ricardoV94 ricardoV94 Feb 17, 2023

The behavior of the logprob when the value variable doesn't match shape is exactly the same as the logprob of the underlying components. If it doesn't match, it will try to broadcast and fail if it cannot. Other than that we don't use size information in any of the core logprob functions.

pm.logp(pm.Normal.dist(size=(4,)), np.ones((2,))) will be happy to return a logp with two values.

Another case where this shows up is in graphs of the form pt.ones((5,)) + pm.Normal.dist(), which we infer to have the same logp as pm.Normal.dist(shape=(5,)), even though the generative process contains only one true random variable, not 5.

I think we need a bigger discussion about the role of shape information in the random and logp graphs, so I wouldn't treat IfElse differently for now.
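The broadcasting behavior described in this comment can be mimicked in plain NumPy. This is only an analogy under stated assumptions, not PyMC's logprob code: `normal_logp` is a hypothetical helper, and the declared size of the distribution deliberately never enters the formula.

```python
import numpy as np


def normal_logp(value, mu=0.0, sigma=1.0):
    """Elementwise Normal log-density; note that no size/shape argument is involved."""
    value = np.asarray(value, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    return -0.5 * ((value - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


# Scalar parameters: whatever size the distribution was declared with,
# a length-2 value simply produces a length-2 logp.
two_values = normal_logp(np.ones(2))

# Length-2 parameters and a length-3 value cannot broadcast, so this fails.
try:
    normal_logp(np.ones(3), mu=[-1.0, 1.0])
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

print(two_values.shape, broadcast_failed)
```

The result shape comes from broadcasting the value against the parameters, so a mismatch is only detected when broadcasting itself is impossible.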

Member Author

@ricardoV94 ricardoV94 Feb 17, 2023

Opened an issue here: #6530

logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)

return ifelse(if_var, logps_then, logps_else)
Member

Is the goal only to allow for explicit conditions in the mixture instead of marginalising?

Member Author

@ricardoV94 ricardoV94 Feb 17, 2023

The applications may vary. We don't explicitly marginalize anything in the logprob submodule, but something like MarginalModel would marginalize this just fine if we can give it the logprob function that this PR offers.
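To illustrate what marginalizing the condition would mean once a conditional logprob is available, here is a minimal sketch. It is not how MarginalModel is implemented; `marginal_logp` is a hypothetical helper, and it assumes each branch's joint logp for the observed value has already been reduced to a scalar.

```python
import numpy as np


def marginal_logp(p_then, logp_then, logp_else):
    """log( p * exp(logp_then) + (1 - p) * exp(logp_else) ), computed stably."""
    return np.logaddexp(np.log(p_then) + logp_then, np.log1p(-p_then) + logp_else)


# Hypothetical branch logps for one observed value, with a Bernoulli(0.5) condition.
marginal = marginal_logp(0.5, logp_then=-3.0, logp_else=-1.0)

# If one branch is infeasible (logp of -inf), only the other branch contributes.
one_sided = marginal_logp(0.5, logp_then=-3.0, logp_else=-np.inf)
print(marginal, one_sided)
```

Working in log space with `np.logaddexp` avoids underflow when the branch logps are very negative.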

@ricardoV94
Member Author

ricardoV94 commented Feb 17, 2023

> Nice approach @ricardoV94. I left a pair of comments. One thing that I found confusing was the first two commits. Are those really central to this PR, or could they be a PR in themselves?

They could be a PR of their own, but this one builds directly on the refactoring done in those commits, so if they were separate I would wait for that PR to be merged first.

@ricardoV94
Member Author

ricardoV94 commented Mar 14, 2023

> One thing that I found confusing was the first two commits. Are those really central to this PR, or could they be a PR in themselves?

These have now been merged elsewhere.

Member

@larryshamalama larryshamalama left a comment

Some more questions :)

@michaelosthege michaelosthege added this to the v5.2.0 milestone Mar 16, 2023
@ricardoV94 ricardoV94 merged commit 0334994 into pymc-devs:main Mar 17, 2023
@ricardoV94 ricardoV94 deleted the ifelse_logprob branch June 6, 2023 03:01
4 participants