Skip to content

Infer logprob of absolute operations and fix logprob of powers with negative values #6414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 7, 2023

Conversation

LukeLB
Copy link
Contributor

@LukeLB LukeLB commented Dec 27, 2022

What is this PR about?
Following from #6400 and implements #6402. This PR allows the use of an absolute transform for cases like

import pymc as pm

x = pm.math.abs(pm.Normal.dist())
y = pm.HalfNormal.dist()
assert pm.logp(x, 2.5).eval() == pm.logp(y, 2.5).eval()

I have included the above example as a test.

Some notes:

  • I wasn't completely sure on how to derive the Jacobian determinant for an absolute function, I think its just the derivative wrt x, so that's what I've included.
  • Should we expect x and y (in the example above) to return the same value if the test_value is negative? e.g.
import pymc as pm

x = pm.math.abs(pm.Normal.dist())
y = pm.HalfNormal.dist()
assert pm.logp(x, -2.5).eval() == pm.logp(y, -2.5).eval()

Because in this case, x returns -inf while y returns nan

Checklist

New features

  • The ability to evaluate the log prob of an absolute transformed random variable

@codecov
Copy link

codecov bot commented Dec 27, 2022

Codecov Report

Merging #6414 (0ce4eec) into main (b081cec) will increase coverage by 8.68%.
The diff coverage is 97.10%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6414      +/-   ##
==========================================
+ Coverage   86.05%   94.73%   +8.68%     
==========================================
  Files         148      148              
  Lines       27646    27698      +52     
==========================================
+ Hits        23792    26241    +2449     
+ Misses       3854     1457    -2397     
Impacted Files Coverage Δ
pymc/logprob/transforms.py 96.13% <92.30%> (+28.35%) ⬆️
pymc/tests/logprob/test_transforms.py 99.75% <100.00%> (+99.75%) ⬆️
pymc/logprob/cumsum.py 100.00% <0.00%> (+3.12%) ⬆️
pymc/tests/logprob/utils.py 50.00% <0.00%> (+3.65%) ⬆️
pymc/logprob/rewriting.py 97.05% <0.00%> (+5.88%) ⬆️
pymc/logprob/abstract.py 97.56% <0.00%> (+6.09%) ⬆️
pymc/logprob/utils.py 100.00% <0.00%> (+13.79%) ⬆️
pymc/logprob/joint_logprob.py 97.01% <0.00%> (+19.40%) ⬆️
pymc/logprob/tensor.py 82.40% <0.00%> (+24.00%) ⬆️
... and 16 more

@LukeLB LukeLB changed the title Infer logprob of absolute operations #6402 Infer logprob of absolute operations Dec 27, 2022
@ricardoV94
Copy link
Member

ricardoV94 commented Dec 27, 2022

Thanks for opening a PR!

Should we expect x and y (in the example above) to return the same value if the test_value is negative? e.g.

Hmm, it should behave the same for negative values... Not sure what's the best way to encode that.

What happens with a sqrt transform? Is the jacobian what leads to -inf?

In that case the jacobian could perhaps be:

return at.switch(value >= 0, 0, -np.inf)?

But the logp will be broken in both cases without jacobian.

Perhaps we need to implement a specific logprob for the absolute and power transforms


Regardless of the solution, we should add test for those cases, hadn't thought of them

@LukeLB
Copy link
Contributor Author

LukeLB commented Dec 27, 2022

Thanks for opening a PR!

Should we expect x and y (in the example above) to return the same value if the test_value is negative? e.g.

Hmm, it should behave the same for negative values... Not sure what's the best way to encode that.

What happens with a sqrt transform? Is the jacobian what leads to -inf?

In that case the jacobian could perhaps be:

return at.switch(value >= 0, 0, -np.inf)?

But the logp will be broken in both cases without jacobian.

Perhaps we need to implement a specific logprob for the absolute and power transforms

Regardless of the solution, we should add test for those cases, hadn't thought of them

Applogies I have gotten this the wrong way round... in the example:

import pymc as pm
x = pm.math.abs(pm.Normal.dist())
y = pm.HalfNormal.dist() 
z = pm.Normal.dist() ** 0.5
print(pm.logp(x, -2.5).eval(), pm.logp(y, -2.5).eval(), pm.logp(z, -2.5).eval())

x = nan, y = -inf, z = nan. So the transforms return nan values.

@ricardoV94
Copy link
Member

Yup, -inf would be preferable

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 3, 2023

Hi @LukeLB

I wanted to leave a review but ended up having to try many things locally that in the end it was easier to push my suggestions as a separate commit.

I realized the issue you found was not specific to power transforms, but any transform that is constrained. For instance the exp transform would also return nan for a negative value. I tweaked the logprob function to look for nan in the jacobian and return a -inf logp in that case. Since the jacobian depends on the values, this should be generally sufficient?

I also tweaked some of your tests and tried to simplify the PowerTransform logic. Let me know what you think, and if you agree with the changes I will clean a bit the commits before merging.

Thanks a lot for the work so far!

@LukeLB
Copy link
Contributor Author

LukeLB commented Jan 3, 2023

@ricardoV94 yes this looks like a more robust solution, just checking for the nan jacobian and returning -inf up stream seems is better especially if any other transformations are added in the future. Is there anything else you would like me to do on this?

@ricardoV94
Copy link
Member

I'll just clean a bit the commits. I don't think anything else is needed, just wanted your input :)

@LukeLB
Copy link
Contributor Author

LukeLB commented Jan 3, 2023

Awesome! Thank you for helping me through this PR!

ricardoV94 and others added 2 commits January 6, 2023 12:26
We rely on the jacobian returning `nan` for invalid values

Co-authored-by: Luke LB <[email protected]>
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaned up the commits, didn't change any content

@ricardoV94 ricardoV94 changed the title Infer logprob of absolute operations Infer logprob of absolute operations and fix logprob of powers with negative values Jan 6, 2023
@ricardoV94 ricardoV94 merged commit db11a23 into pymc-devs:main Jan 7, 2023
@ricardoV94
Copy link
Member

Thanks @LukeLB! This one required a bit of digging because it wasn't really working from the previous PRs, nice finding!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants