Skip to content

Commit c65baf7

Browse files
Raise NotImplementedError in Group.__init__
1 parent fef06a2 commit c65baf7

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

pymc3/variational/opvi.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from pymc3.model import modelcontext
6464
from pymc3.util import get_default_varnames, get_transformed
6565
from pymc3.variational.updates import adagrad_window
66+
from pymc3.vartypes import discrete_types
6667

6768
__all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"]
6869

@@ -831,6 +832,9 @@ def __init__(
831832
options=None,
832833
**kwargs,
833834
):
835+
# XXX: Needs to be refactored for v4
836+
raise NotImplementedError("This class needs to be refactored for v4")
837+
834838
if local and not self.supports_batched:
835839
raise LocalGroupError("%s does not support local groups" % self.__class__)
836840
if local and rowwise:
@@ -957,7 +961,7 @@ def __init_group__(self, group):
957961
# self.ordering = ArrayOrdering([])
958962
self.replacements = dict()
959963
for var in self.group:
960-
if isinstance(var.distribution, pm.Discrete):
964+
if var.type.numpy_dtype.name in discrete_types:
961965
raise ParametrizationError(f"Discrete variables are not supported by VI: {var}")
962966
begin = self.ddim
963967
if self.batched:

0 commit comments

Comments
 (0)