19
19
import re
20
20
import warnings
21
21
from collections .abc import Sequence
22
+ from typing import Annotated , Literal
22
23
23
24
try :
24
25
import networkx as nx
@@ -637,26 +638,43 @@ class TBFPC:
637
638
- Kass, R. & Raftery, A. (1995). "Bayes Factors."
638
639
"""
639
640
641
+ @validate_call (config = dict (arbitrary_types_allowed = True ))
640
642
def __init__ (
641
643
self ,
642
- target : str ,
644
+ target : Annotated [
645
+ str ,
646
+ Field (
647
+ min_length = 1 ,
648
+ description = "Name of the outcome variable to orient the search." ,
649
+ ),
650
+ ],
643
651
* ,
644
- target_edge_rule : str = "any" ,
645
- bf_thresh : float = 1.0 ,
652
+ target_edge_rule : Literal [ "any" , "conservative" , "fullS" ] = "any" ,
653
+ bf_thresh : Annotated [ float , Field ( gt = 0.0 )] = 1.0 ,
646
654
forbidden_edges : Sequence [tuple [str , str ]] | None = None ,
647
655
):
656
+ """Create a new TBFPC causal discovery model.
657
+
658
+ Parameters
659
+ ----------
660
+ target
661
+ Variable name for the model outcome; must be present in the data
662
+ used during fitting.
663
+ target_edge_rule
664
+ Rule that controls which driver → target edges are retained.
665
+ Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
666
+ bf_thresh
667
+ Positive Bayes factor threshold applied during conditional
668
+ independence tests.
669
+ forbidden_edges
670
+ Optional sequence of node pairs that must not be connected in the
671
+ learned graph.
672
+ """
648
673
warnings .warn (
649
674
"TBFPC is experimental and its API may change; use with caution." ,
650
675
UserWarning ,
651
676
stacklevel = 2 ,
652
677
)
653
- if not isinstance (target , str ) or not target :
654
- raise ValueError ("target must be a non-empty string" )
655
- allowed_rules = {"any" , "conservative" , "fullS" }
656
- if target_edge_rule not in allowed_rules :
657
- raise ValueError (f"target_edge_rule must be one of { allowed_rules } " )
658
- if not isinstance (bf_thresh , (int , float )) or bf_thresh <= 0 :
659
- raise ValueError ("bf_thresh must be a positive float" )
660
678
661
679
self .target = target
662
680
self .target_edge_rule = target_edge_rule
@@ -670,7 +688,8 @@ def __init__(
670
688
self .nodes_ : list [str ] = []
671
689
self .test_results : dict [tuple [str , str , frozenset ], dict [str , float ]] = {}
672
690
673
- # Shared response vector for symbolic BIC
691
+ # Shared response vector for symbolic BIC computation
692
+ # Initialized with placeholder; will be updated with actual data during fitting
674
693
self .y_sh = pytensor .shared (np .zeros (1 , dtype = "float64" ), name = "y_sh" )
675
694
self ._bic_fn = self ._build_symbolic_bic_fn ()
676
695
@@ -708,17 +727,47 @@ def _remove_all(self, u: str, v: str) -> None:
708
727
self ._adj_directed .discard ((v , u ))
709
728
710
729
def _build_symbolic_bic_fn (self ):
711
- """Build and compile a function to compute BIC given a design matrix ``X`` and sample size ``n`` ."""
730
+ """Build a BIC callable using a fast solver with a pseudoinverse fallback ."""
712
731
X = tt .matrix ("X" )
713
732
n = tt .iscalar ("n" )
714
733
715
- beta = tt .nlinalg .pinv (X ) @ self .y_sh
716
- resid = self .y_sh - X @ beta
717
- rss = tt .sum (resid ** 2 )
734
+ xtx = tt .dot (X .T , X )
735
+ xty = tt .dot (X .T , self .y_sh )
736
+
737
+ beta_solve = tt .linalg .solve (xtx , xty )
738
+ resid_solve = self .y_sh - tt .dot (X , beta_solve )
739
+ rss_solve = tt .sum (resid_solve ** 2 )
740
+
741
+ beta_pinv = tt .nlinalg .pinv (X ) @ self .y_sh
742
+ resid_pinv = self .y_sh - tt .dot (X , beta_pinv )
743
+ rss_pinv = tt .sum (resid_pinv ** 2 )
744
+
718
745
k = X .shape [1 ]
719
746
720
- bic = n * tt .log (rss / n ) + k * tt .log (n )
721
- return pytensor .function ([X , n ], bic )
747
+ nf = tt .cast (n , "float64" )
748
+ rss_solve_safe = tt .maximum (rss_solve , np .finfo ("float64" ).tiny )
749
+ rss_pinv_safe = tt .maximum (rss_pinv , np .finfo ("float64" ).tiny )
750
+
751
+ bic_solve = nf * tt .log (rss_solve_safe / nf ) + k * tt .log (nf )
752
+ bic_pinv = nf * tt .log (rss_pinv_safe / nf ) + k * tt .log (nf )
753
+
754
+ bic_solve_fn = pytensor .function (
755
+ [X , n ], [bic_solve , rss_solve ], on_unused_input = "ignore" , mode = "FAST_RUN"
756
+ )
757
+ bic_pinv_fn = pytensor .function (
758
+ [X , n ], bic_pinv , on_unused_input = "ignore" , mode = "FAST_RUN"
759
+ )
760
+
761
+ def bic_fn (X_val : np .ndarray , n_val : int ) -> float :
762
+ try :
763
+ bic_value , rss_value = bic_solve_fn (X_val , n_val )
764
+ if np .isfinite (rss_value ) and rss_value > np .finfo ("float64" ).tiny :
765
+ return float (bic_value )
766
+ except (np .linalg .LinAlgError , RuntimeError , ValueError ):
767
+ pass
768
+ return float (bic_pinv_fn (X_val , n_val ))
769
+
770
+ return bic_fn
722
771
723
772
def _ci_independent (
724
773
self , df : pd .DataFrame , x : str , y : str , cond : Sequence [str ]
@@ -923,30 +972,50 @@ class TBF_FCI:
923
972
Whether to allow contemporaneous edges at time t.
924
973
"""
925
974
975
+ @validate_call (config = dict (arbitrary_types_allowed = True ))
926
976
def __init__ (
927
977
self ,
928
- target : str ,
978
+ target : Annotated [
979
+ str ,
980
+ Field (
981
+ min_length = 1 ,
982
+ description = "Name of the outcome variable at time t." ,
983
+ ),
984
+ ],
929
985
* ,
930
- target_edge_rule : str = "any" ,
931
- bf_thresh : float = 1.0 ,
986
+ target_edge_rule : Literal [ "any" , "conservative" , "fullS" ] = "any" ,
987
+ bf_thresh : Annotated [ float , Field ( gt = 0.0 )] = 1.0 ,
932
988
forbidden_edges : Sequence [tuple [str , str ]] | None = None ,
933
- max_lag : int = 2 ,
989
+ max_lag : Annotated [ int , Field ( ge = 0 )] = 2 ,
934
990
allow_contemporaneous : bool = True ,
935
991
):
992
+ """Create a new temporal TBF-PC causal discovery model.
993
+
994
+ Parameters
995
+ ----------
996
+ target
997
+ Target variable name at time ``t`` that the algorithm orients
998
+ toward.
999
+ target_edge_rule
1000
+ Rule used to retain lagged → target edges. Choose from
1001
+ ``"any"``, ``"conservative"``, or ``"fullS"``.
1002
+ bf_thresh
1003
+ Positive Bayes factor threshold applied during conditional
1004
+ independence testing.
1005
+ forbidden_edges
1006
+ Optional sequence of node pairs that must be excluded from the
1007
+ final graph.
1008
+ max_lag
1009
+ Maximum lag (inclusive) to consider when constructing temporal
1010
+ drivers.
1011
+ allow_contemporaneous
1012
+ Whether contemporaneous edges at time ``t`` are permitted.
1013
+ """
936
1014
warnings .warn (
937
1015
"TBF_FCI is experimental and its API may change; use with caution." ,
938
1016
UserWarning ,
939
1017
stacklevel = 2 ,
940
1018
)
941
- if not isinstance (target , str ) or not target :
942
- raise ValueError ("target must be a non-empty string" )
943
- allowed_rules = {"any" , "conservative" , "fullS" }
944
- if target_edge_rule not in allowed_rules :
945
- raise ValueError (f"target_edge_rule must be one of { allowed_rules } " )
946
- if not isinstance (bf_thresh , (int , float )) or bf_thresh <= 0 :
947
- raise ValueError ("bf_thresh must be a positive float" )
948
- if not isinstance (max_lag , int ) or max_lag < 0 :
949
- raise ValueError ("max_lag must be a non-negative integer" )
950
1019
951
1020
self .target = target
952
1021
self .target_edge_rule = target_edge_rule
@@ -961,6 +1030,8 @@ def __init__(
961
1030
self .nodes_ : list [str ] = []
962
1031
self .test_results : dict [tuple [str , str , frozenset ], dict [str , float ]] = {}
963
1032
1033
+ # Shared response vector for symbolic BIC computation
1034
+ # Initialized with placeholder; will be updated with actual data during fitting
964
1035
self .y_sh = pytensor .shared (np .zeros (1 , dtype = "float64" ), name = "y_sh" )
965
1036
self ._bic_fn = self ._build_symbolic_bic_fn ()
966
1037
@@ -1048,14 +1119,43 @@ def _remove_all(self, u: str, v: str) -> None:
1048
1119
self ._adj_directed .discard ((v , u ))
1049
1120
1050
1121
def _build_symbolic_bic_fn (self ):
1122
+ """Build a BIC callable using a fast solver with a pseudoinverse fallback."""
1051
1123
X = tt .matrix ("X" )
1052
1124
n = tt .iscalar ("n" )
1053
- beta = tt .nlinalg .pinv (X ) @ self .y_sh
1054
- resid = self .y_sh - X @ beta
1055
- rss = tt .sum (resid ** 2 )
1125
+
1126
+ xtx = tt .dot (X .T , X )
1127
+ xty = tt .dot (X .T , self .y_sh )
1128
+
1129
+ beta_solve = tt .linalg .solve (xtx , xty )
1130
+ resid_solve = self .y_sh - tt .dot (X , beta_solve )
1131
+ rss_solve = tt .sum (resid_solve ** 2 )
1132
+
1133
+ beta_pinv = tt .nlinalg .pinv (X ) @ self .y_sh
1134
+ resid_pinv = self .y_sh - tt .dot (X , beta_pinv )
1135
+ rss_pinv = tt .sum (resid_pinv ** 2 )
1136
+
1056
1137
k = X .shape [1 ]
1057
- bic = n * tt .log (rss / n ) + k * tt .log (n )
1058
- return pytensor .function ([X , n ], bic )
1138
+
1139
+ bic_solve = n * tt .log (rss_solve / n ) + k * tt .log (n )
1140
+ bic_pinv = n * tt .log (rss_pinv / n ) + k * tt .log (n )
1141
+
1142
+ bic_solve_fn = pytensor .function (
1143
+ [X , n ], bic_solve , on_unused_input = "ignore" , mode = "FAST_RUN"
1144
+ )
1145
+ bic_pinv_fn = pytensor .function (
1146
+ [X , n ], bic_pinv , on_unused_input = "ignore" , mode = "FAST_RUN"
1147
+ )
1148
+
1149
+ def bic_fn (X_val : np .ndarray , n_val : int ) -> float :
1150
+ try :
1151
+ value = float (bic_solve_fn (X_val , n_val ))
1152
+ if np .isfinite (value ):
1153
+ return value
1154
+ except (np .linalg .LinAlgError , RuntimeError , ValueError ):
1155
+ pass
1156
+ return float (bic_pinv_fn (X_val , n_val ))
1157
+
1158
+ return bic_fn
1059
1159
1060
1160
def _ci_independent (
1061
1161
self , df : pd .DataFrame , x : str , y : str , cond : Sequence [str ]
@@ -1194,15 +1294,17 @@ def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
1194
1294
self ._stageB_contemporaneous (L , drivers )
1195
1295
return self
1196
1296
1197
- def collapsed_summary (self ):
1198
- collapsed_directed = []
1297
+ def collapsed_summary (self ) -> tuple [list [tuple [str , str , int ]], list [tuple [str , str ]]]:
1298
+ """Return collapsed summary of lagged directed and undirected edges."""
1299
+
1300
+ collapsed_directed : list [tuple [str , str , int ]] = []
1199
1301
for u , v in self ._adj_directed :
1200
1302
base_u , lag_u = self ._parse_lag (u )
1201
1303
base_v , lag_v = self ._parse_lag (v )
1202
1304
if lag_v == 0 :
1203
1305
collapsed_directed .append ((base_u , base_v , lag_u ))
1204
1306
1205
- collapsed_undirected = []
1307
+ collapsed_undirected : list [ tuple [ str , str ]] = []
1206
1308
for u , v in self ._adj_undirected :
1207
1309
base_u , lag_u = self ._parse_lag (u )
1208
1310
base_v , lag_v = self ._parse_lag (v )
0 commit comments