Allow batched scalar sigma in ZeroSumNormal #7063

Merged
ricardoV94 merged 1 commit into pymc-devs:main on Dec 13, 2023

Conversation

@ricardoV94 (Member) commented Dec 12, 2023

Description

The previous check was meant to prevent a non-scalar sigma across the zero-sum axes, but there is no reason to disallow a sigma that varies across batch dimensions.
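
For illustration only (not part of the PR), here is a minimal sketch of the usage this enables, assuming the public `pm.ZeroSumNormal` API: `sigma` varies across a batch dimension but has length one along the zero-sum axis.

```python
import numpy as np
import pymc as pm

with pm.Model():
    # sigma has shape (3, 1): it differs across the batch dimension of
    # length 3 but is constant along the zero-sum (last) axis of length 4.
    # Only the latter is restricted; batching sigma is now allowed.
    sigma = np.array([1.0, 2.0, 3.0])[:, None]
    x = pm.ZeroSumNormal("x", sigma=sigma, n_zerosum_axes=1, shape=(3, 4))
```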

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7063.org.readthedocs.build/en/7063/

codecov bot commented Dec 12, 2023

Codecov Report

Merging #7063 (9477d2e) into main (2e05854) will increase coverage by 0.00%.
Report is 1 commit behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

@@           Coverage Diff           @@
##             main    #7063   +/-   ##
=======================================
  Coverage   92.19%   92.20%           
=======================================
  Files         101      101           
  Lines       16893    16895    +2     
=======================================
+ Hits        15575    15578    +3     
+ Misses       1318     1317    -1     
Files                                 Coverage Δ
pymc/distributions/multivariate.py    93.49% <100.00%> (+0.12%) ⬆️

@lucianopaz (Member) left a comment

Looks good to me. I just left one nitpick and one question in the code.

@@ -2706,7 +2702,12 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:

     @classmethod
     def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
-        shape = to_tuple(size) + tuple(support_shape)
+        if size is not None:
+            shape = tuple(size) + tuple(support_shape)
@lucianopaz (Member) commented:

Do we need to check that size and sigma.shape are compatible or that done somewhere else before this call?

@ricardoV94 (Member, Author) replied:
We shouldn't have to check. Shape issues will arise when creating the Normal below, or at evaluation time when broadcasting fails between size and sigma.
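
A hypothetical sketch (not from the PR) of the point above: an incompatible batch shape between `sigma` and the requested shape does not need an upstream check, because it already fails when the underlying Normal is built or when a draw is evaluated.

```python
import numpy as np
import pymc as pm

# sigma's batch dimension (5) cannot broadcast with the requested batch
# dimension (3); the mismatch surfaces on its own, either while the
# underlying Normal is constructed or when the draw is evaluated.
sigma = np.ones((5, 1))
try:
    x = pm.ZeroSumNormal.dist(sigma=sigma, n_zerosum_axes=1, shape=(3, 4))
    x.eval()
except Exception as err:
    print(f"{type(err).__name__}: {err}")
```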

@ricardoV94 force-pushed the fix_batched_sigma_zsn branch from 705fb55 to 9477d2e on December 13, 2023 10:31
@ricardoV94 merged commit 35cd657 into pymc-devs:main on Dec 13, 2023
2 participants