Skip to content

Extend logprob inference for scans with carried auxiliary states #6582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 16, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 8, 2023

PyMC can now infer graphs like MA(2)

import numpy as np
import pytensor

import pymc as pm

init_eps = pm.DiracDelta.dist(np.zeros(2, dtype="float64"))
rho = [0.3, 0.9]
sigma = 1
n_steps = 100

def ma2_step(eps_tm2, eps_tm1, rho, sigma):
    mu = eps_tm1 * rho[0] + eps_tm2 * rho[1]
    y = pm.Normal.dist(mu, sigma)
    eps = y - mu
    return eps, y

[_, ma2], _ = pytensor.scan(
    fn=ma2_step,
    outputs_info=[{"initial": init_eps, "taps": range(-2, 0)}, None],
    non_sequences=[rho, sigma],
    n_steps=n_steps,
    name="ma2",
)

ma2_test = ma2.eval()
pm.logp(ma2, ma2_test).sum().eval()

CC @jessegrabowski

@codecov
Copy link

codecov bot commented Mar 8, 2023

Codecov Report

Merging #6582 (eddcc8f) into main (c709abf) will decrease coverage by 14.28%.
The diff coverage is 12.50%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #6582       +/-   ##
===========================================
- Coverage   92.02%   77.75%   -14.28%     
===========================================
  Files          92       92               
  Lines       15563    15592       +29     
===========================================
- Hits        14322    12123     -2199     
- Misses       1241     3469     +2228     
Impacted Files Coverage Δ
pymc/logprob/scan.py 17.67% <7.40%> (-79.25%) ⬇️
pymc/testing.py 82.19% <23.07%> (-9.65%) ⬇️

... and 53 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the fix_scan_infer_logp_ma2 branch from 7b0dadd to eddcc8f Compare March 14, 2023 10:40
@ricardoV94 ricardoV94 requested a review from Armavica March 14, 2023 10:40
Copy link
Member

@larryshamalama larryshamalama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouf, Scans a bit beyond my area of comfort. I just have minor comments for this one

@twiecki
Copy link
Member

twiecki commented Mar 16, 2023

Need to add tags so this shows up in release notes.

@ricardoV94 ricardoV94 changed the title Extend logprob inference for scans with carried deterministic states Extend logprob inference for scans with carried auxiliary states Mar 16, 2023
@ricardoV94 ricardoV94 merged commit 5dcd101 into pymc-devs:main Mar 16, 2023
@ricardoV94
Copy link
Member Author

Need to add tags so this shows up in release notes.

Already had. Without tags it still shows up in the release notes, but in the catch-all "maintenance" bucket. The only case it doesn't show up is if it has the label "no releasenotes"

@ricardoV94 ricardoV94 deleted the fix_scan_infer_logp_ma2 branch June 6, 2023 03:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants