
Better float32 sampling support for TruncatedNormal #7026


Merged (7 commits, Nov 28, 2023)

Conversation

@JasonTam (Contributor) commented Nov 23, 2023

When using float32 (via PYTENSOR_FLAGS='floatX=float32'), the current use of stats.truncnorm.rvs in rng_fn can return nan for TruncatedNormalRV (and probably for other distributions whose rng_fn wraps a SciPy-flavored rvs). The reason is described in the issues linked below.

An issue was raised upstream (but will be closed without a fix):
scipy/scipy#19554 (and, at a higher level, scipy/scipy#15928)

This PR is meant as a draft based on the suggested workaround and as a starting point for discussing how best to proceed. For now, I have personally been using a custom patched distribution class with the changes in this PR.
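For reference, a minimal sketch reproducing the failure mode (it mirrors the current rng_fn call; whether nan actually appears depends on the SciPy version):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# float32 parameters, as produced under PYTENSOR_FLAGS='floatX=float32'
mu, sigma, lower, upper = np.asarray([1, 3, 0, 10], dtype=np.float32)
draws = stats.truncnorm.rvs(
    a=(lower - mu) / sigma,
    b=(upper - mu) / sigma,
    loc=mu,
    scale=sigma,
    size=100_000,
    random_state=rng,
)
print(np.isnan(draws).any())  # can print True on affected SciPy versions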


📚 Documentation preview 📚: https://pymc--7026.org.readthedocs.build/en/7026/

codecov bot commented Nov 23, 2023

Codecov Report

Merging #7026 (abe6db9) into main (547bcb4) will decrease coverage by 1.54%.
Report is 8 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7026      +/-   ##
==========================================
- Coverage   92.17%   90.64%   -1.54%     
==========================================
  Files         101      101              
  Lines       16849    16827      -22     
==========================================
- Hits        15530    15252     -278     
- Misses       1319     1575     +256     
Files Coverage Δ
pymc/distributions/continuous.py 97.77% <100.00%> (-0.03%) ⬇️

... and 5 files with indirect coverage changes

@ricardoV94 (Member)

@JasonTam thanks for the investigation and write-up. I think the other suggestion in the linked thread is easier to maintain: upcast the parameters to float64 and downcast to the dtype of the RV afterwards.

ricardoV94 added the bug label Nov 23, 2023
@JasonTam (Contributor, Author)

@ricardoV94 I've updated the PR to showcase the upcast/downcast method. Based on some quick tests, performance seems comparable between the two methods, with casting a bit faster:

import pymc as pm
from pymc import TruncatedNormal

with pm.Model() as model_cmp:
    # CustomTruncatedNormal1/2: patched variants using the manual-inverse and cast/cast rng_fn's
    x_control = TruncatedNormal("x_control", mu=1, sigma=3, lower=0, upper=10, size=100_000)
    x_manual_inv = CustomTruncatedNormal1("x_manual_inv", mu=1, sigma=3, lower=0, upper=10, size=100_000)
    x_cast_cast = CustomTruncatedNormal2("x_cast_cast", mu=1, sigma=3, lower=0, upper=10, size=100_000)
---
%%timeit
x_control_samp = x_control.eval()
8.99 ms ± 61.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_manual_inv_samp = x_manual_inv.eval()
9.74 ms ± 83.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_cast_cast_samp = x_cast_cast.eval()
9.18 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In terms of the pymc library, I'm curious whether this is worth fixing for all SciPy RVs, or just mentioning in comments/docs, or whether having this PR here is enough for posterity while scipy/scipy#15928 is worked on.

@ricardoV94 (Member)

No, we shouldn't do anything special for all functions. This is a SciPy bug, and it also seems to depend on exactly how the RV is generated. If it were more widespread, it would be worth pressing SciPy to fix it sooner.

@ricardoV94 (Member) commented Nov 23, 2023

Calling eval is not a good way to benchmark: most of the time is spent compiling the function and very little actually evaluating it. You could simply time the Python code directly.

In any case, RVs are not as time-sensitive as other parts of the library. Otherwise we would not be using simple Python code for them to begin with.
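For example (a sketch: compile once with compile_pymc, which is assumed importable from pymc.pytensorf and takes care of the RNG updates, then time only the evaluation):

from pymc.pytensorf import compile_pymc

# x_control is the TruncatedNormal variable from the earlier comparison
f = compile_pymc([], x_control)

%timeit f()  # times evaluation only, not compilation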

@JasonTam (Contributor, Author) commented Nov 24, 2023

Oh, I see. I tried timing the rng_fn implementations themselves and saw very similar results (see below).
Anyway, I guess we can close this then(?), and hopefully anyone who runs into this issue will find this workaround here.

from typing import List, Optional, Union

import numpy as np
from scipy import stats


def rng_fn_control(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    # Current behavior: pass the (possibly float32) parameters straight through to SciPy
    return stats.truncnorm.rvs(
        a=(lower - mu) / sigma,
        b=(upper - mu) / sigma,
        loc=mu,
        scale=sigma,
        size=size,
        random_state=rng,
    )
    
def rng_fn_manual_inv(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    # Workaround 1: draw uniforms at the input dtype and invert the CDF manually
    a = (lower - mu) / sigma
    b = (upper - mu) / sigma
    dist = stats.truncnorm(a=a, b=b, loc=mu, scale=sigma)
    # Underlying uniform should be of the same dtype as other inputs (`mu` for now)
    ps = rng.random(size=size, dtype=mu.dtype)
    return dist.ppf(ps).clip(lower, upper)

def rng_fn_cast_cast(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    # Workaround 2: upcast the parameters to float64 for SciPy, then downcast
    # the draws back to the input dtype
    return stats.truncnorm.rvs(
        a=((lower - mu) / sigma).astype('float64'),
        b=((upper - mu) / sigma).astype('float64'),
        loc=mu.astype('float64'),
        scale=sigma.astype('float64'),
        size=size,
        random_state=rng,
    ).astype(mu.dtype)
---
rng = np.random.default_rng(0)
mu, sigma, lower, upper = np.asarray([1, 3, 0, 10], dtype=np.float32)
size = 100_000
---
%%timeit
x_control_samp = rng_fn_control(rng, mu, sigma, lower, upper, size)
9.08 ms ± 48.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_manual_inv_samp = rng_fn_manual_inv(rng, mu, sigma, lower, upper, size)
9.87 ms ± 65.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_cast_cast_samp = rng_fn_cast_cast(rng, mu, sigma, lower, upper, size)
9.24 ms ± 98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
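
For completeness, a quick sanity check (a sketch reusing the definitions above) that the cast/cast variant preserves the input dtype and avoids the nans:

samp = rng_fn_cast_cast(rng, mu, sigma, lower, upper, size)
assert samp.dtype == np.float32   # draws come back in the input dtype
assert not np.isnan(samp).any()   # no nans from the float32 issue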

@ricardoV94 (Member)

Looks good to me, just need to run pre-commit. It's complaining about the unused import

@twiecki (Member) commented Nov 26, 2023

/pre-commit-run

twiecki marked this pull request as ready for review November 28, 2023 13:15