diff --git a/pymc_marketing/mmm/causal.py b/pymc_marketing/mmm/causal.py index 3b3eb1701..a3fc5b845 100644 --- a/pymc_marketing/mmm/causal.py +++ b/pymc_marketing/mmm/causal.py @@ -11,11 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Causal identification class.""" +"""Causal module.""" +from __future__ import annotations + +import re import warnings +try: + import networkx as nx +except ImportError: # Optional dependency + nx = None # type: ignore[assignment] import pandas as pd +import pymc as pm +import pytensor.tensor as pt +from pydantic import Field, InstanceOf, validate_call +from pymc_extras.prior import Prior try: from dowhy import CausalModel @@ -34,6 +45,476 @@ def __init__(self, *args, **kwargs): CausalModel = LazyCausalModel +class BuildModelFromDAG: + """Build a PyMC probabilistic model directly from a Causal DAG and a tabular dataset. + + The class interprets a Directed Acyclic Graph (DAG) where each node is a column + in the provided `df`. For every edge ``A -> B`` it creates a slope prior for + the contribution of ``A`` into the mean of ``B``. Each node receives a + likelihood prior. Dims and coords are used to align and index observed data + via ``pm.Data`` and xarray. + + Parameters + ---------- + dag : str + DAG in DOT format (e.g. ``digraph { A -> B; B -> C; }``) or as a simple + comma/newline separated list of edges (e.g. ``"A->B, B->C"``). + df : pandas.DataFrame + DataFrame that contains a column for every node present in the DAG and + all columns named by the provided ``dims``. + target : str + Name of the target node present in both the DAG and ``df``. This is not + used to restrict modeling but is validated to exist in the DAG. + dims : tuple[str, ...] + Dims for the observed variables and likelihoods (e.g. ``("date", "channel")``). + coords : dict + Mapping from dim names to coordinate values. All coord keys must exist as + columns in ``df`` and will be used to pivot the data to match dims. + model_config : dict, optional + Optional configuration with priors for keys ``"intercept"``, ``"slope"`` and + ``"likelihood"``. Values should be ``pymc_extras.prior.Prior`` instances. + Missing keys fall back to :pyattr:`default_model_config`. + + Examples + -------- + Minimal example using DOT format: + + .. code-block:: python + + import numpy as np + import pandas as pd + + from pymc_marketing.mmm.causal import BuildModelFromDAG + + dates = pd.date_range("2024-01-01", periods=5, freq="D") + df = pd.DataFrame( + { + "date": dates, + "X": np.random.normal(size=5), + "Y": np.random.normal(size=5), + } + ) + + dag = "digraph { X -> Y; }" + dims = ("date",) + coords = {"date": dates} + + builder = BuildModelFromDAG( + dag=dag, df=df, target="Y", dims=dims, coords=coords + ) + model = builder.build() + + Edge-list format and custom likelihood prior: + + .. 
code-block:: python + + from pymc_extras.prior import Prior + + dag = "X->Y" # equivalent to the DOT example above + model_config = { + "likelihood": Prior( + "StudentT", nu=5, sigma=Prior("HalfNormal", sigma=1), dims=("date",) + ), + } + + builder = BuildModelFromDAG( + dag=dag, + df=df, + target="Y", + dims=("date",), + coords={"date": dates}, + model_config=model_config, + ) + model = builder.build() + """ + + @validate_call + def __init__( + self, + *, + dag: str = Field(..., description="DAG in DOT string format or A->B list"), + df: InstanceOf[pd.DataFrame] = Field( + ..., description="DataFrame containing all DAG node columns" + ), + target: str = Field(..., description="Target node name present in DAG and df"), + dims: tuple[str, ...] = Field( + ..., description="Dims for observed/likelihood variables" + ), + coords: dict = Field( + ..., + description=( + "Required coords mapping for dims and priors. All coord keys must exist as columns in df." + ), + ), + model_config: dict | None = Field( + None, + description=( + "Optional model config with Priors for 'intercept', 'slope' and " + "'likelihood'. Keys not supplied fall back to defaults." + ), + ), + ) -> None: + self.dag = dag + self.df = df + self.target = target + self.dims = dims + self.coords = coords + + # Parse graph and validate target + self.graph = self._parse_dag(self.dag) + self.nodes = list(nx.topological_sort(self.graph)) + if self.target not in self.nodes: + raise ValueError(f"Target '{self.target}' not in DAG nodes: {self.nodes}") + + # Merge provided model_config with defaults + provided = model_config + self.model_config = self.default_model_config + if provided is not None: + self.model_config.update(provided) + + # Validate required priors are present and of correct type + self._validate_model_config_priors() + + # Validate coords are present and consistent with dims, priors, and df + self._validate_coords_required_are_consistent() + + # Validate prior dims consistency early (does not require building the model) + self._warning_if_slope_dims_dont_match_likelihood_dims() + self._validate_intercept_dims_match_slope_dims() + + @property + def default_model_config(self) -> dict[str, Prior]: + """Default priors for intercepts, slopes and likelihood using ``pymc_extras.Prior``. + + Returns + ------- + dict + Dictionary with keys ``"intercept"``, ``"slope"`` and ``"likelihood"`` + mapping to ``Prior`` instances with dims derived from + :pyattr:`dims`. + """ + slope_dims = tuple(dim for dim in (self.dims or ()) if dim != "date") + return { + "intercept": Prior("Normal", mu=0, sigma=1, dims=slope_dims), + "slope": Prior("Normal", mu=0, sigma=1, dims=slope_dims), + "likelihood": Prior( + "Normal", + sigma=Prior("HalfNormal", sigma=1), + dims=self.dims, + ), + } + + @staticmethod + def _parse_dag(dag_str: str) -> nx.DiGraph: + """Parse DOT digraph or edge-list string into a directed acyclic graph.""" + if nx is None: + raise ImportError( + "To use Causal Graph functionality, please install the optional dependencies with: " + "pip install pymc-marketing[dag]" + ) + # Primary format: DOT digraph + s = dag_str.strip() + g = nx.DiGraph() + + if s.lower().startswith("digraph"): + # Extract content within the first top-level {...} + brace_start = s.find("{") + brace_end = s.rfind("}") + if brace_start == -1 or brace_end == -1 or brace_end <= brace_start: + raise ValueError("Malformed DOT digraph: missing braces") + body = s[brace_start + 1 : brace_end] + + # Remove comments (// ... or # ... 
at line end) + lines = [] + for raw_line in body.splitlines(): + line = re.split(r"//|#", raw_line, maxsplit=1)[0].strip() + if line: + lines.append(line) + body = "\n".join(lines) + + # Find edges "A -> B" possibly ending with ';' + for m in re.finditer( + r"\b([A-Za-z0-9_]+)\s*->\s*([A-Za-z0-9_]+)\s*;?", body + ): + a, b = m.group(1), m.group(2) + g.add_edge(a, b) + + # Find standalone node declarations (lines with single identifier, optional ';') + for raw_line in body.splitlines(): + line = raw_line.strip().rstrip(";") + if not line or "->" in line or "[" in line or "]" in line: + continue + mnode = re.match(r"^([A-Za-z0-9_]+)$", line) + if mnode: + g.add_node(mnode.group(1)) + + else: + # Fallback: simple comma/newline-separated "A->B" tokens + edges: list[tuple[str, str]] = [] + for token in re.split(r"[,\n]+", s): + token = token.strip().rstrip(";") + if not token: + continue + medge = re.match(r"^([A-Za-z0-9_]+)\s*->\s*([A-Za-z0-9_]+)$", token) + if not medge: + raise ValueError(f"Invalid edge token: '{token}'") + a, b = medge.group(1), medge.group(2) + edges.append((a, b)) + g.add_edges_from(edges) + + if not nx.is_directed_acyclic_graph(g): + raise ValueError("Provided graph is not a DAG.") + return g + + def _warning_if_slope_dims_dont_match_likelihood_dims(self) -> None: + """Warn if slope prior dims differ from likelihood dims without the 'date' dim.""" + slope_prior = self.model_config["slope"] + likelihood_prior = self.model_config["likelihood"] + + like_dims = getattr(likelihood_prior, "dims", None) + if isinstance(like_dims, str): + like_dims = (like_dims,) + elif isinstance(like_dims, list): + like_dims = tuple(like_dims) + + # Guard against None dims (treat as empty) + if like_dims is None: + expected_slope_dims = () + else: + expected_slope_dims = tuple(dim for dim in like_dims if dim != "date") + + slope_dims = getattr(slope_prior, "dims", None) + if slope_dims is None or not isinstance(slope_dims, tuple): + slope_dims = () + elif isinstance(slope_dims, str): + slope_dims = (slope_dims,) + elif isinstance(slope_dims, list): + slope_dims = tuple(slope_dims) + + if slope_dims != expected_slope_dims: + warnings.warn( + ( + "Slope prior dims " + f"{slope_dims if slope_dims else '()'} do not match expected dims " + f"{expected_slope_dims} (likelihood dims without 'date')." + ), + stacklevel=2, + ) + + def _validate_intercept_dims_match_slope_dims(self) -> None: + """Ensure intercept prior dims match slope prior dims exactly.""" + + def _to_tuple(maybe_dims): + if maybe_dims is None: + return tuple() + if isinstance(maybe_dims, str): + return (maybe_dims,) + if isinstance(maybe_dims, (list, tuple)): + return tuple(maybe_dims) + return tuple() + + slope_dims = _to_tuple(getattr(self.model_config["slope"], "dims", None)) + intercept_dims = _to_tuple( + getattr(self.model_config["intercept"], "dims", None) + ) + + if slope_dims != intercept_dims: + raise ValueError( + "model_config['intercept'].dims must match model_config['slope'].dims. " + f"Got intercept dims {intercept_dims or '()'} and slope dims {slope_dims or '()'}." + ) + + def _validate_model_config_priors(self) -> None: + """Ensure required model_config entries are Prior instances. + + Enforces that keys 'slope' and 'likelihood' exist and are Prior objects, + so downstream code can safely index and call Prior helper methods. 
+ """ + required_keys = ("intercept", "slope", "likelihood") + for key in required_keys: + if key not in self.model_config: + raise ValueError(f"model_config must include '{key}' as a Prior.") + for key in required_keys: + if not isinstance(self.model_config[key], Prior): + raise TypeError( + f"model_config['{key}'] must be a Prior, got " + f"{type(self.model_config[key]).__name__}." + ) + + def _validate_coords_required_are_consistent(self) -> None: + """Validate mutual consistency among dims, coords, priors, and data columns.""" + if self.coords is None: + raise ValueError("'coords' is required and cannot be None.") + + # 1) All coords keys must correspond to columns in the dataset + for key in self.coords.keys(): + if key not in self.df.columns: + raise KeyError( + f"Coordinate key '{key}' not found in DataFrame columns. Present columns: {list(self.df.columns)}" + ) + + # 2) Ensure dims are present in coords + for d in self.dims: + if d not in self.coords: + raise ValueError(f"Missing coordinate values for dim '{d}' in coords.") + + # 3) Ensure Prior.dims exist in coords (for all top-level priors we manage) + def _to_tuple(maybe_dims): + if isinstance(maybe_dims, str): + return (maybe_dims,) + if isinstance(maybe_dims, (list, tuple)): + return tuple(maybe_dims) + else: + return tuple() + + for prior_name, prior in self.model_config.items(): + if not isinstance(prior, Prior): + continue + for d in _to_tuple(getattr(prior, "dims", None)): + if d not in self.coords: + raise ValueError( + f"Dim '{d}' declared in Prior '{prior_name}' must be present in coords." + ) + + # 4) Enforce that likelihood dims match class dims exactly + likelihood_prior = self.model_config["likelihood"] + likelihood_dims = _to_tuple(getattr(likelihood_prior, "dims", None)) + if likelihood_dims and tuple(self.dims) != likelihood_dims: + raise ValueError( + "Likelihood Prior dims " + f"{likelihood_dims} must match class dims {tuple(self.dims)}. " + "When supplying a custom model_config, ensure likelihood.dims equals the 'dims' argument." + ) + + def _parents(self, node: str) -> list[str]: + """Return the list of parent node names for the given DAG node.""" + return list(self.graph.predecessors(node)) + + def build(self) -> pm.Model: + """Construct and return the PyMC model implied by the DAG and data. + + The method creates a ``pm.Data`` container for every node to align the + observed data with the declared ``dims``. For each edge ``A -> B``, a + slope prior is instantiated from ``model_config['slope']`` and used in the + mean of node ``B``'s likelihood, which is instantiated from + ``model_config['likelihood']``. + + Returns + ------- + pymc.Model + A fully specified model with slopes and likelihoods for all nodes. + + Examples + -------- + Build a model and sample from it: + + .. code-block:: python + + builder = BuildModelFromDAG( + dag="A->B", df=df, target="B", dims=("date",), coords={"date": dates} + ) + model = builder.build() + with model: + idata = pm.sample(100, tune=100, chains=2, cores=2) + + Multi-dimensional dims (e.g. date and country): + + .. 
code-block:: python + + dims = ("date", "country") + coords = {"date": dates, "country": ["Venezuela", "Colombia"]} + builder = BuildModelFromDAG( + dag="A->B, B->Y", df=df, target="Y", dims=dims, coords=coords + ) + model = builder.build() + """ + dims = self.dims + coords = self.coords + + with pm.Model(coords=coords) as model: + data_containers: dict[str, pm.Data] = {} + for node in self.nodes: + if node not in self.df.columns: + raise KeyError(f"Column '{node}' not found in df.") + # Ensure observed data has shape consistent with declared dims by pivoting via xarray + indexed = self.df.set_index(list(dims)) + xarr = indexed.to_xarray()[node] + values = xarr.values + + data_containers[node] = pm.Data(f"_{node}", values, dims=dims) + + # For each node add slope priors per parent and likelihood with sigma prior + slope_rvs: dict[tuple[str, str], pt.TensorVariable] = {} + + # Create priors in a stable deterministic order + for node in self.nodes: + parents = self._parents(node) + # Slopes for each parent -> node + mu_expr = 0 + for parent in parents: + slope_name = f"{parent.lower()}{node.lower()}" + slope_rv = self.model_config["slope"].create_variable(slope_name) + slope_rvs[(parent, node)] = slope_rv + mu_expr += slope_rv * data_containers[parent] + intercept_rv = self.model_config["intercept"].create_variable( + f"{node.lower()}_intercept" + ) + + self.model_config["likelihood"].create_likelihood_variable( + name=node, + mu=mu_expr + intercept_rv, + observed=data_containers[node], + ) + + self.model = model + return self.model + + def model_graph(self): + """Return a Graphviz visualization of the built PyMC model. + + Returns + ------- + graphviz.Source + Graphviz object representing the model graph. + + Examples + -------- + .. code-block:: python + + model = builder.build() + g = builder.model_graph() + g + """ + if not hasattr(self, "model"): + raise RuntimeError("Call build() first.") + return pm.model_to_graphviz(self.model) + + def dag_graph(self): + """Return a copy of the parsed DAG as a NetworkX directed graph. + + Returns + ------- + networkx.DiGraph + A directed acyclic graph with the same nodes and edges as the input DAG. + + Examples + -------- + .. code-block:: python + + g = builder.dag_graph() + list(g.edges()) + """ + if nx is None: + raise ImportError( + "To use Causal Graph functionality, please install the optional dependencies with: " + "pip install pymc-marketing[dag]" + ) + g = nx.DiGraph() + g.add_nodes_from(self.graph.nodes) + g.add_edges_from(self.graph.edges) + return g + + class CausalGraphModel: """Represent a causal model based on a Directed Acyclic Graph (DAG). @@ -64,7 +545,7 @@ def __init__( @classmethod def build_graphical_model( cls, graph: str, treatment: list[str] | tuple[str], outcome: str - ) -> "CausalGraphModel": + ) -> CausalGraphModel: """Create a CausalGraphModel from a string representation of a graph. Parameters diff --git a/tests/mmm/test_causal.py b/tests/mmm/test_causal.py index 243caf54b..354bdc9cc 100644 --- a/tests/mmm/test_causal.py +++ b/tests/mmm/test_causal.py @@ -13,14 +13,609 @@ # limitations under the License. 
import warnings +import graphviz +import networkx as nx +import numpy as np +import pandas as pd +import pymc as pm import pytest +from pydantic import ValidationError +from pymc_extras.prior import Prior -from pymc_marketing.mmm.causal import CausalGraphModel +from pymc_marketing.mmm.causal import BuildModelFromDAG, CausalGraphModel # Suppress specific dowhy warnings globally warnings.filterwarnings("ignore", message="The graph defines .* variables") +@pytest.fixture +def causal_df(): + rng = np.random.default_rng(123) + N = 500 + Q = rng.normal(size=N) + X = rng.normal(loc=0.14 * Q, scale=0.4, size=N) + Y = rng.normal(loc=0.7 * X + 0.11 * Q, scale=0.24, size=N) + P = rng.normal(loc=0.43 * X + 0.21 * Y, scale=0.22, size=N) + + dates = pd.date_range("2023-01-01", periods=N, freq="D") + return pd.DataFrame({"date": dates, "Q": Q, "X": X, "Y": Y, "P": P}) + + +def test_build_raises_when_coords_key_not_in_df(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + # Inject an extra coordinate not present in the dataframe columns + coords = {"date": causal_df["date"].to_numpy()} + coords["ghost"] = np.arange(len(causal_df)) + + with pytest.raises( + KeyError, match="Coordinate key 'ghost' not found in DataFrame columns" + ): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date",), + coords=coords, + ) + + +def test_build_raises_when_df_missing_column_present_in_coords(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + # Inject an extra coordinate not present in the dataframe columns + coords = {"date": causal_df["date"].to_numpy()} + df_missing_date = causal_df.drop(columns=["date"]) # Remove date from dataset + + with pytest.raises( + KeyError, match="Coordinate key 'date' not found in DataFrame columns" + ): + BuildModelFromDAG( + dag=dag, + df=df_missing_date, + target="Y", + dims=("date",), + coords=coords, + ) + + +def test_build_with_custom_priors_builds(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + + # Custom priors with matching dims expectation (likelihood has 'date', slope has no dims) + custom_config = { + "intercept": Prior("Normal", mu=0, sigma=0.5), + "slope": Prior("Normal", mu=0, sigma=0.5), # no dims implies () + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=0.5), dims=("date",) + ), + } + + coords = {"date": causal_df["date"].unique()} + + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date",), + coords=coords, + model_config=custom_config, + ) + + model = builder.build() + assert isinstance(model, pm.Model) + + +def test_warning_when_slope_dims_missing_vs_likelihood_dims(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + + causal_df["country"] = "Venezuela" + + custom_config = { + "intercept": Prior("Normal", mu=0, sigma=1), + "slope": Prior("Normal", mu=0, sigma=1), # no dims + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date", "country") + ), + } + + coords = { + "date": causal_df["date"].unique(), + "country": causal_df["country"].unique(), + } + + with pytest.warns(UserWarning, match="Slope prior dims"): + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date", "country"), + coords=coords, + model_config=custom_config, + ) + model = builder.build() + assert isinstance(model, pm.Model) + + +def test_no_warning_when_slope_dims_match_likelihood_dims(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + + 
causal_df["country"] = "Venezuela" + + custom_config = { + "intercept": Prior("Normal", mu=0, sigma=1, dims=("country",)), + "slope": Prior("Normal", mu=0, sigma=1, dims=("country",)), + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date", "country") + ), + } + + coords = { + "date": causal_df["date"].unique(), + "country": causal_df["country"].unique(), + } + + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date", "country"), + coords=coords, + model_config=custom_config, + ) + model = builder.build() + + assert isinstance(model, pm.Model) + + +def test_error_when_likelihood_dims_differ_from_class_dims(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + Y -> P; + } + """ + + causal_df["country"] = "Venezuela" + + # Class dims only includes date, while likelihood dims include date and country -> should error + custom_config = { + "intercept": Prior("Normal", mu=0, sigma=1), + "slope": Prior("Normal", mu=0, sigma=1), + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date", "country") + ), + } + + coords = { + "date": causal_df["date"].unique(), + "country": causal_df["country"].unique(), + } + + with pytest.raises( + ValueError, match=r"Likelihood Prior dims .* must match class dims .*" + ): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date",), + coords=coords, + model_config=custom_config, + ) + + +def test_model_and_dag_graph_return_types(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + } + """ + + coords = {"date": causal_df["date"].unique()} + + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date",), + coords=coords, + ) + model = builder.build() + assert isinstance(model, pm.Model) + + mg = builder.model_graph() + dg = builder.dag_graph() + assert isinstance(mg, graphviz.Digraph) + assert isinstance(dg, nx.DiGraph) + + +def test_default_model_config_contents_and_types(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + } + """ + + coords = {"date": causal_df["date"].unique()} + + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date",), + coords=coords, + ) + + cfg = builder.model_config + assert set(cfg.keys()) >= {"intercept", "slope", "likelihood"} + assert isinstance(cfg["intercept"], Prior) + assert isinstance(cfg["slope"], Prior) + assert isinstance(cfg["likelihood"], Prior) + + # Check default dims + like_dims = cfg["likelihood"].dims + if isinstance(like_dims, str): + like_dims = (like_dims,) + elif isinstance(like_dims, list): + like_dims = tuple(like_dims) + assert like_dims == ("date",) + + slope_dims = cfg["slope"].dims + if slope_dims is None: + slope_dims = tuple() + elif isinstance(slope_dims, str): + slope_dims = (slope_dims,) + elif isinstance(slope_dims, list): + slope_dims = tuple(slope_dims) + intercept_dims = cfg["intercept"].dims + if intercept_dims is None: + intercept_dims = tuple() + elif isinstance(intercept_dims, str): + intercept_dims = (intercept_dims,) + elif isinstance(intercept_dims, list): + intercept_dims = tuple(intercept_dims) + # Expect dims without 'date' -> empty tuple + assert slope_dims == tuple() + assert intercept_dims == slope_dims + + +def test_parse_dag_parses_dot_and_simple_formats(): + # DOT format + dag_dot = """ + digraph { + A -> B; + B -> C; + } + """ + g_dot = BuildModelFromDAG._parse_dag(dag_dot) + assert isinstance(g_dot, nx.DiGraph) + assert set(g_dot.edges()) == {("A", "B"), ("B", "C")} + + # Simple A->B tokens format + dag_simple = 
"A->B, B->C" + g_simple = BuildModelFromDAG._parse_dag(dag_simple) + assert isinstance(g_simple, nx.DiGraph) + assert set(g_simple.edges()) == {("A", "B"), ("B", "C")} + + # Cycle should raise + with pytest.raises(ValueError, match="not a DAG"): + BuildModelFromDAG._parse_dag("A->B, B->A") + + +def test_init_raises_when_target_not_in_dag(causal_df): + dag = """ + digraph { + A -> B; + } + """ + + coords = {"date": causal_df["date"].unique()} + + with pytest.raises(ValueError, match=r"Target 'Z' not in DAG nodes"): + BuildModelFromDAG( + dag=dag, + df=causal_df.rename(columns={"Q": "A", "X": "B"}), + target="Z", + dims=("date",), + coords=coords, + ) + + +def test_parse_dag_malformed_dot_raises(): + malformed = "digraph { A -> B;" # missing closing brace + with pytest.raises(ValueError, match="Malformed DOT digraph: missing braces"): + BuildModelFromDAG._parse_dag(malformed) + + +def test_parse_dag_handles_comments_and_standalone_nodes(): + dag = """ + digraph { + // comment line + A; + A -> B; // edge comment + C; # standalone node with hash comment + B -> C; + } + """ + g = BuildModelFromDAG._parse_dag(dag) + assert set(g.edges()) == {("A", "B"), ("B", "C")} + assert set(g.nodes()) >= {"A", "B", "C"} + + +def test_parse_dag_invalid_simple_token_raises(): + with pytest.raises(ValueError, match="Invalid edge token"): + BuildModelFromDAG._parse_dag("A-B, C->D") + + +def test_validate_coords_raises_when_coords_none(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + # Pydantic validate_call intercepts before our internal check + with pytest.raises(ValidationError): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=None, + ) + + +def test_validate_coords_raises_when_dim_missing_in_coords(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + causal_df["country"] = "Venezuela" + coords = {"date": causal_df["date"].unique()} + with pytest.raises( + ValueError, match=r"Missing coordinate values for dim 'country'" + ): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date", "country"), + coords=coords, + ) + + +def test_validate_coords_raises_when_prior_dims_not_in_coords(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + coords = {"date": causal_df["date"].unique()} + custom_config = { + prior_name: Prior("Normal", mu=0, sigma=1, dims=("country",)) + for prior_name in ("intercept", "slope") + } + with pytest.raises( + ValueError, + match=r"Dim 'country' declared in Prior '(?:intercept|slope)' must be present in coords", + ): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=coords, + model_config=custom_config, + ) + + +def test_no_warning_when_dims_given_as_str_and_list(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + } + """ + causal_df["country"] = "Venezuela" + custom_config = { + "intercept": Prior("Normal", mu=0, sigma=1, dims="country"), + "slope": Prior("Normal", mu=0, sigma=1, dims="country"), + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=1), dims=["date", "country"] + ), + } + coords = { + "date": causal_df["date"].unique(), + "country": causal_df["country"].unique(), + } + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date", "country"), + coords=coords, + model_config=custom_config, + ) + model = builder.build() + assert isinstance(model, pm.Model) + + +def test_likelihood_dims_none_init_ok(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + coords = {"date": causal_df["date"].unique()} + custom_config = 
{ + "intercept": Prior("Normal", mu=0, sigma=1), + "slope": Prior("Normal", mu=0, sigma=1), + "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims=None), + } + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=coords, + model_config=custom_config, + ) + assert isinstance(builder, BuildModelFromDAG) + + +def test_validate_coords_required_raises_valueerror_when_none(causal_df): + """Test that directly calling _validate_coords_required_are_consistent raises ValueError when coords is None.""" + dag = """ + digraph { + Q -> X; + } + """ + + # Create builder without going through pydantic validation + builder = object.__new__(BuildModelFromDAG) + builder.dag = dag + builder.df = causal_df + builder.target = "X" + builder.dims = ("date",) + builder.coords = None # Explicitly set to None + builder.graph = BuildModelFromDAG._parse_dag(dag) + builder.nodes = list(nx.topological_sort(builder.graph)) + builder.model_config = { + "intercept": Prior("Normal", mu=0, sigma=1), + "slope": Prior("Normal", mu=0, sigma=1), + "likelihood": Prior( + "Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date",) + ), + } + + # This should raise the specific ValueError + with pytest.raises(ValueError, match=r"'coords' is required and cannot be None\."): + builder._validate_coords_required_are_consistent() + + +def test_error_when_likelihood_in_model_config_is_none(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + coords = {"date": causal_df["date"].unique()} + with pytest.raises( + TypeError, match=r"model_config\['likelihood'\] must be a Prior" + ): + BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=coords, + model_config={ + "intercept": Prior("Normal", mu=0, sigma=1), + "likelihood": None, + "slope": Prior("Normal", mu=0, sigma=1), + }, + ) + + +def test_build_raises_when_missing_column_from_df(causal_df): + dag = """ + digraph { + A -> B; + } + """ + # Create df missing column 'B' + df = causal_df.rename(columns={"Q": "A"})[["date", "A"]] + coords = {"date": df["date"].unique()} + builder = BuildModelFromDAG( + dag=dag, + df=df, + target="B", + dims=("date",), + coords=coords, + ) + with pytest.raises(KeyError, match="Column 'B' not found in df"): + builder.build() + + +def test_model_graph_raises_when_called_before_build(causal_df): + dag = """ + digraph { + Q -> X; + } + """ + coords = {"date": causal_df["date"].unique()} + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=coords, + ) + with pytest.raises(RuntimeError, match=r"Call build\(\) first"): + builder.model_graph() + + +def test_default_model_config_slope_dims_excludes_date_multi_dim(causal_df): + dag = """ + digraph { + Q -> X; + X -> Y; + } + """ + causal_df["country"] = "Venezuela" + coords = { + "date": causal_df["date"].unique(), + "country": causal_df["country"].unique(), + } + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="Y", + dims=("date", "country"), + coords=coords, + ) + slope_dims = builder.model_config["slope"].dims + if isinstance(slope_dims, str): + slope_dims = (slope_dims,) + elif isinstance(slope_dims, list): + slope_dims = tuple(slope_dims) + elif slope_dims is None: + slope_dims = tuple() + assert slope_dims == ("country",) + + @pytest.mark.filterwarnings("ignore:The graph defines .* variables") @pytest.mark.parametrize( "dag, treatment, outcome, expected_adjustment_set", @@ -173,3 +768,83 @@ def test_compute_adjustment_sets( assert adjusted_controls == 
expected_controls, ( f"Expected {expected_controls}, but got {adjusted_controls}" ) + + +def test_networkx_import_error_in_parse_dag(monkeypatch): + """Test that _parse_dag raises ImportError when networkx is not available.""" + # Mock nx to be None + monkeypatch.setattr("pymc_marketing.mmm.causal.nx", None) + + with pytest.raises( + ImportError, + match=( + r"To use Causal Graph functionality, please install the " + r"optional dependencies with: pip install pymc-marketing\[dag\]" + ), + ): + BuildModelFromDAG._parse_dag("A->B") + + +def test_networkx_import_error_in_dag_graph(causal_df, monkeypatch): + """Test that dag_graph raises ImportError when networkx is not available.""" + dag = """ + digraph { + Q -> X; + } + """ + coords = {"date": causal_df["date"].unique()} + + # First create the builder normally + builder = BuildModelFromDAG( + dag=dag, + df=causal_df, + target="X", + dims=("date",), + coords=coords, + ) + + # Then mock nx to be None and test dag_graph + monkeypatch.setattr("pymc_marketing.mmm.causal.nx", None) + + with pytest.raises( + ImportError, + match=( + r"To use Causal Graph functionality, please install the " + r"optional dependencies with: pip install pymc-marketing\[dag\]" + ), + ): + builder.dag_graph() + + +def test_causal_model_lazy_import_when_dowhy_missing(monkeypatch): + """Test that CausalModel uses LazyCausalModel when dowhy is not available.""" + # Simulate dowhy not being installed by removing it from sys.modules + import sys + + import pymc_marketing.mmm.causal as causal_module + + # Save the original CausalModel + original_causal_model = causal_module.CausalModel + + try: + monkeypatch.setitem(sys.modules, "dowhy", None) + + # Force reload of the causal module to trigger the import error path + import importlib + + importlib.reload(causal_module) + + # Now CausalModel should be the LazyCausalModel that raises ImportError + with pytest.raises( + ImportError, + match=( + r"To use Causal Graph functionality, please install the " + r"optional dependencies with: pip install pymc-marketing\[dag\]" + ), + ): + causal_module.CausalModel( + data=pd.DataFrame(), graph="A->B", treatment=["A"], outcome="B" + ) + finally: + # Restore the original CausalModel to prevent test pollution + causal_module.CausalModel = original_causal_model
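

# --- Illustrative usage sketch ---
# A minimal end-to-end run of the BuildModelFromDAG class added above, using the
# edge-list DAG format and the default model_config. The synthetic DataFrame,
# variable names, seed, and draw counts below are assumptions for illustration only.
import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.mmm.causal import BuildModelFromDAG

rng = np.random.default_rng(42)
dates = pd.date_range("2024-01-01", periods=60, freq="D")
x = rng.normal(size=60)
y = 0.7 * x + rng.normal(scale=0.3, size=60)
df = pd.DataFrame({"date": dates, "X": x, "Y": y})

builder = BuildModelFromDAG(
    dag="X->Y",              # equivalent to "digraph { X -> Y; }"
    df=df,
    target="Y",
    dims=("date",),
    coords={"date": dates},
)

model = builder.build()       # slope "xy" for the X -> Y edge, plus per-node intercepts and likelihoods
dag = builder.dag_graph()     # networkx.DiGraph copy of the parsed DAG
graph = builder.model_graph()  # graphviz rendering of the built PyMC model (requires build() first)

with model:
    idata = pm.sample(draws=200, tune=200, chains=2)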