Better float32 sampling support for TruncatedNormal #7026
Conversation
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #7026      +/-   ##
==========================================
- Coverage   92.17%   90.64%   -1.54%
==========================================
  Files         101      101
  Lines       16849    16827      -22
==========================================
- Hits        15530    15252     -278
- Misses       1319     1575     +256
```
@JasonTam thanks for the investigation and write-up. I think the other suggestion in the linked thread is easier to maintain: upcast the parameters to float64 and downcast to the dtype of the RV afterwards.
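In isolation, that suggestion amounts to something like the following (a sketch only; the helper name and signature are illustrative, and the PR's concrete version appears as `rng_fn_cast_cast` further down):

```python
import numpy as np
from scipy import stats


def truncnorm_float64_roundtrip(a, b, loc, scale, size, rng, out_dtype=np.float32):
    # Illustrative helper (not PR code): sample in float64 throughout,
    # then cast the result back to the RV's dtype (e.g. float32).
    draws = stats.truncnorm.rvs(
        a=np.asarray(a, dtype=np.float64),
        b=np.asarray(b, dtype=np.float64),
        loc=np.asarray(loc, dtype=np.float64),
        scale=np.asarray(scale, dtype=np.float64),
        size=size,
        random_state=rng,
    )
    return draws.astype(out_dtype)
```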
@ricardoV94 I've updated the PR to showcase the upcast/downcast method:

```python
with pm.Model() as model_cmp:
    x_control = TruncatedNormal("x_control", mu=1, sigma=3, lower=0, upper=10, size=100_000)
    x_manual_inv = CustomTruncatedNormal1("x_manual_inv", mu=1, sigma=3, lower=0, upper=10, size=100_000)
    x_cast_cast = CustomTruncatedNormal2("x_cast_cast", mu=1, sigma=3, lower=0, upper=10, size=100_000)
```
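(`CustomTruncatedNormal1` and `CustomTruncatedNormal2` are not defined in this comment. A plausible reconstruction of the second one, assuming pymc's usual `RandomVariable` subclassing pattern — class names and wiring here are illustrative, not the PR's actual code — might look like this; the right mechanism may differ across pymc versions:)

```python
import numpy as np
from scipy import stats

from pymc.distributions.continuous import TruncatedNormal, TruncatedNormalRV


class CustomTruncatedNormal2RV(TruncatedNormalRV):
    # Hypothetical override: same math as the stock rng_fn, but scipy
    # works in float64 and the draws are cast back to the input dtype.
    @classmethod
    def rng_fn(cls, rng, mu, sigma, lower, upper, size):
        mu64 = np.asarray(mu, dtype=np.float64)
        sigma64 = np.asarray(sigma, dtype=np.float64)
        draws = stats.truncnorm.rvs(
            a=(np.asarray(lower, dtype=np.float64) - mu64) / sigma64,
            b=(np.asarray(upper, dtype=np.float64) - mu64) / sigma64,
            loc=mu64,
            scale=sigma64,
            size=size,
            random_state=rng,
        )
        return draws.astype(np.asarray(mu).dtype)


class CustomTruncatedNormal2(TruncatedNormal):
    rv_op = CustomTruncatedNormal2RV()
```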
---
%%timeit
x_control_samp = x_control.eval()
8.99 ms ± 61.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_manual_inv_samp = x_manual_inv.eval()
9.74 ms ± 83.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_cast_cast_samp = x_cast_cast.eval()
9.18 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I'm just curious: in terms of the pymc library, is this worth fixing for all scipy RVs, or just mentioning in comments/docs? Or maybe just having this PR here will be enough for posterity while scipy/scipy#15928 is worked on.
No, we shouldn't do anything special for all functions. This is a Scipy bug, but it also seems to depend on how exactly the RV is being generated. If it were more widespread, it would be worth pressing Scipy to fix it sooner.
In any case, RVs are not as time-sensitive as other aspects of the library; otherwise we would not be using simple Python code for them to begin with.
Oh, I see. I tried timing the `rng_fn` implementations directly:

```python
from typing import List, Optional, Union

import numpy as np
from scipy import stats


def rng_fn_control(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    return stats.truncnorm.rvs(
        a=(lower - mu) / sigma,
        b=(upper - mu) / sigma,
        loc=mu,
        scale=sigma,
        size=size,
        random_state=rng,
    )


def rng_fn_manual_inv(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    a = (lower - mu) / sigma
    b = (upper - mu) / sigma
    dist = stats.truncnorm(a=a, b=b, loc=mu, scale=sigma)
    # Underlying uniform should be of the same dtype as the other inputs (`mu` for now)
    ps = rng.random(size=size, dtype=mu.dtype)
    return dist.ppf(ps).clip(lower, upper)


def rng_fn_cast_cast(
    rng: np.random.RandomState,
    mu: Union[np.ndarray, float],
    sigma: Union[np.ndarray, float],
    lower: Union[np.ndarray, float],
    upper: Union[np.ndarray, float],
    size: Optional[Union[List[int], int]],
) -> np.ndarray:
    return stats.truncnorm.rvs(
        a=((lower - mu) / sigma).astype("float64"),
        b=((upper - mu) / sigma).astype("float64"),
        loc=mu.astype("float64"),
        scale=sigma.astype("float64"),
        size=size,
        random_state=rng,
    ).astype(mu.dtype)


# Shared benchmark inputs: float32 parameters, as under floatX=float32
rng = np.random.default_rng(0)
mu, sigma, lower, upper = np.asarray([1, 3, 0, 10], dtype=np.float32)
size = 100_000
```
---
%%timeit
x_control_samp = rng_fn_control(rng, mu, sigma, lower, upper, size)
9.08 ms ± 48.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_manual_inv_samp = rng_fn_manual_inv(rng, mu, sigma, lower, upper, size)
9.87 ms ± 65.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
---
%%timeit
x_cast_cast_samp = rng_fn_cast_cast(rng, mu, sigma, lower, upper, size)
9.24 ms ± 98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Looks good to me; just need to run pre-commit. It's complaining about an unused import.
/pre-commit-run
When using float32 (via `PYTENSOR_FLAGS='floatX=float32'`), the current usage of `stats.truncnorm.rvs` in `rng_fn` can return `nan` for `TruncatedNormalRV` (and probably other functions that use a scipy-flavored `rvs`). The reason is described in the issue linked below.

An issue was raised here (but will be closed without a fix): scipy/scipy#19554 (and, at a higher level, scipy/scipy#15928)
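For reference, a minimal sketch of the failure mode (illustrative only; whether `nan`s actually appear depends on the scipy version and parameter values, per the linked issues):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# float32 parameters, as produced under PYTENSOR_FLAGS='floatX=float32'
mu, sigma, lower, upper = np.asarray([1, 3, 0, 10], dtype=np.float32)

draws = stats.truncnorm.rvs(
    a=(lower - mu) / sigma,
    b=(upper - mu) / sigma,
    loc=mu,
    scale=sigma,
    size=100_000,
    random_state=rng,
)
print("nan draws:", np.isnan(draws).sum())  # nonzero on affected scipy versions
```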
This PR is meant to be just a draft, based on the suggested workaround, and a starting point for discussing how best to proceed. For now, I have personally just been using a custom patched distribution class with the changes in this PR.
📚 Documentation preview 📚: https://pymc--7026.org.readthedocs.build/en/7026/