99NEG_EXAMPLE_WEIGHT = 1
1010
1111class Combiner :
12- def __init__ (self , settings , tester ):
12+ def __init__ (self , settings , tester , coverage_pos , coverage_neg , prog_lookup ):
1313 self .settings = settings
1414 self .tester = tester
1515 self .best_cost = None
1616 self .saved_progs = set ()
1717 self .inconsistent = set ()
18+ self .coverage_pos = coverage_pos
19+ self .coverage_neg = coverage_neg
20+ self .prog_lookup = prog_lookup
1821
19- def add_inconsistent (self , prog ):
20- self .inconsistent .add (prog )
22+ def add_inconsistent (self , prog_hash ):
23+ self .inconsistent .add (prog_hash )
2124
2225 def find_combination (self , timeout ):
2326 encoding = []
@@ -30,29 +33,30 @@ def find_combination(self, timeout):
3033 base_rules = []
3134 recursive_rules = []
3235
33- programs_covering_example = defaultdict (list )
36+ programs_covering_pos_example = {}
37+ programs_covering_neg_example = {}
3438 program_var = {}
3539 program_clauses = {}
3640 vpool = IDPool ()
37- example_covered_var = {}
3841
39- for i in self . settings . pos_index :
40- example_covered_var [ i ] = vpool . id ( "example_covered({0})" . format ( i ))
42+ pos_example_covered_var = {}
43+ neg_example_covered_var = {}
4144
42- if not self .settings .nonoise :
43- for i in self .settings .neg_index :
44- example_covered_var [i ] = vpool .id ("example_covered({0})" .format (i ))
45+ pos_index = list (range (self .tester .num_pos ))
46+ neg_index = list (range (self .tester .num_neg ))
4547
46- # print('moo', t2-t1)
47- # print('starting to build' )
48- # t1 = time.time()
48+ for i in pos_index :
49+ pos_example_covered_var [ i ] = vpool . id ( "pos_example_covered({0})" . format ( i ) )
50+ programs_covering_pos_example [ i ] = []
4951
50- rule_var = {}
51-
52- for program_count , prog in enumerate (self .saved_progs ):
52+ if self .settings .noisy :
53+ for i in neg_index :
54+ neg_example_covered_var [i ] = vpool .id ("neg_example_covered({0})" .format (i ))
55+ programs_covering_neg_example [i ] = []
5356
54- # print('B', format_prog(prog), hash(prog))
57+ rule_var = {}
5558
59+ for program_count , prog_hash in enumerate (self .saved_progs ):
5660 # UNCOMMENT TO SHOW PROGRAMS ADDED TO THE SOLVER
5761 # tp = len(pos_covered)
5862 # fp = len(neg_covered)
@@ -61,21 +65,21 @@ def find_combination(self, timeout):
6165 # print(f'size: {size} fp:{fp} tp:{tp} mdl:{size + fp + fn} {format_prog(prog)}')
6266 # print(sorted(pos_covered))
6367
64-
65- prog_hash = hash (prog )
66-
67- # if prog_hash in self.to_delete:
68- # continue
68+ prog = self .prog_lookup [prog_hash ]
6969
7070 pos_covered = self .coverage_pos [prog_hash ]
7171 neg_covered = self .coverage_neg [prog_hash ]
7272
73- for ex in pos_covered :
74- programs_covering_example [ex ].append (program_count )
73+ for ex , x in enumerate (pos_covered ):
74+ if x == 1 :
75+ # AC: REALLY REALLY HACKY
76+ # WE NEED TO +1 BECAUSE OF POS_INDEX BELOW
77+ programs_covering_pos_example [ex ].append (program_count )
7578
7679 if self .settings .noisy :
77- for ex in neg_covered :
78- programs_covering_example [ex ].append (program_count )
80+ for ex , x in enumerate (neg_covered ):
81+ if x == 1 :
82+ programs_covering_neg_example [ex ].append (program_count )
7983
8084 rule_vars = []
8185 ids = []
@@ -89,6 +93,7 @@ def find_combination(self, timeout):
8993 ids .append (k )
9094 else :
9195 ids .append (rulehash_to_id [rule_hash ])
96+
9297 for rule in prog :
9398 rule_hash = hash (rule )
9499 rule_id = rulehash_to_id [rule_hash ]
@@ -122,20 +127,20 @@ def find_combination(self, timeout):
122127 if self .settings .lex and self .settings .recursion_enabled :
123128 encoding .append ([rule_var [rule_id ] for rule_id in base_rules ])
124129
125- for ex in self .settings .pos_index :
126- encoding .append ([- example_covered_var [ex ]] + [program_var [p ] for p in programs_covering_example [ex ]])
127- if not self .settings .nonoise :
128- for ex in self .settings .neg_index :
129- for p in programs_covering_example [ex ]:
130- encoding .append ([example_covered_var [ex ], - program_var [p ]])
130+ for ex in pos_index :
131+ encoding .append ([- pos_example_covered_var [ex ]] + [program_var [p ] for p in programs_covering_pos_example [ex ]])
132+
133+ if self .settings .noisy :
134+ for ex in neg_index :
135+ for p in programs_covering_neg_example [ex ]:
136+ encoding .append ([neg_example_covered_var [ex ], - program_var [p ]])
131137
132138 soft_clauses = []
133139 weights = []
134140
135141 if self .settings .best_prog_score :
136142 tp_ , fn_ , tn_ , fp_ , size_ = self .settings .best_prog_score
137143
138- # with self.settings.stats.duration('combine.add'):
139144 if self .settings .lex :
140145 soft_lit_groups = []
141146 rule_soft_lits = []
@@ -145,37 +150,37 @@ def find_combination(self, timeout):
145150 weights .append (ruleid_to_size [rule_id ])
146151 if self .settings .best_prog_score :
147152 if fn_ == 0 :
148- for i in self . settings . pos_index :
149- encoding .append ([example_covered_var [i ]])
153+ for i in pos_index :
154+ encoding .append ([pos_example_covered_var [i ]])
150155 if fp_ == 0 :
151156 if not self .settings .nonoise :
152- for i in self . settings . neg_index :
153- encoding .append ([- example_covered_var [i ]])
157+ for i in neg_index :
158+ encoding .append ([- neg_example_covered_var [i ]])
154159 soft_lit_groups = [[lit for lit in rule_soft_lits ]]
155160 else :
156- soft_lit_groups = [[- example_covered_var [i ] for i in self . settings . neg_index ]]
161+ soft_lit_groups = [[- neg_example_covered_var [i ] for i in neg_index ]]
157162 soft_lit_groups .append ([lit for lit in rule_soft_lits ])
158163 else :
159- soft_lit_groups = [[example_covered_var [i ] for i in self . settings . pos_index ]]
164+ soft_lit_groups = [[pos_example_covered_var [i ] for i in pos_index ]]
160165 if not self .settings .nonoise :
161- soft_lit_groups .append ([- example_covered_var [i ] for i in self . settings . neg_index ])
166+ soft_lit_groups .append ([- neg_example_covered_var [i ] for i in neg_index ])
162167 soft_lit_groups .append ([lit for lit in rule_soft_lits ])
163168 else :
164- soft_lit_groups = [[example_covered_var [i ] for i in self . settings . pos_index ]]
169+ soft_lit_groups = [[pos_example_covered_var [i ] for i in pos_index ]]
165170 if not self .settings .nonoise :
166- soft_lit_groups .append ([- example_covered_var [i ] for i in self . settings . neg_index ])
171+ soft_lit_groups .append ([- neg_example_covered_var [i ] for i in neg_index ])
167172 soft_lit_groups .append ([lit for lit in rule_soft_lits ])
168173 else :
169174 for rule_id in rule_var :
170175 if rule_var [rule_id ] is not None :
171176 soft_clauses .append ([- rule_var [rule_id ]])
172177 weights .append (ruleid_to_size [rule_id ])
173- for i in self . settings . pos_index :
174- soft_clauses .append ([example_covered_var [i ]])
178+ for i in pos_index :
179+ soft_clauses .append ([pos_example_covered_var [i ]])
175180 weights .append (POS_EXAMPLE_WEIGHT )
176181 if not self .settings .nonoise :
177- for i in self . settings . neg_index :
178- soft_clauses .append ([- example_covered_var [i ]])
182+ for i in neg_index :
183+ soft_clauses .append ([- neg_example_covered_var [i ]])
179184 weights .append (NEG_EXAMPLE_WEIGHT )
180185
181186 # PRUNE INCONSISTENT
@@ -190,14 +195,10 @@ def find_combination(self, timeout):
190195 ids .append (k )
191196 if not should_add :
192197 continue
193- # print('MOO')
194198 ids = [rulehash_to_id [k ] for k in ids ]
195199 clause = [- rule_var [k ] for k in ids ]
196200 encoding .append (clause )
197201
198- # t2 = time.time()
199- # print('building time', t2-t1)
200-
201202 best_prog = []
202203 best_fp = False
203204 best_fn = False
@@ -210,12 +211,6 @@ def find_combination(self, timeout):
210211 model_found = False
211212 model_inconsistent = False
212213
213-
214-
215- # print('solving')
216- # t1 = time.time()
217-
218-
219214 if not self .settings .lex :
220215 if timeout is None or self .settings .last_combine_stage :
221216 cost , model = maxsat .exact_maxsat_solve (encoding , soft_clauses , weights , self .settings )
@@ -227,17 +222,14 @@ def find_combination(self, timeout):
227222 else :
228223 cost , model = maxsat .anytime_lex_solve (encoding , soft_lit_groups , weights , self .settings , timeout )
229224
230- # t2 = time.time()
231- # print('solving time', t2-t1)
232-
233225 if model is None :
234226 print ("WARNING: No solution found, exit combiner." )
235227 break
236228
237- fn = sum (1 for i in self . settings . pos_index if model [example_covered_var [i ]- 1 ] < 0 )
229+ fn = sum (1 for i in pos_index if model [pos_example_covered_var [i ]- 1 ] < 0 )
238230 fp = 0
239231 if not self .settings .nonoise :
240- fp = sum (1 for i in self . settings . neg_index if model [example_covered_var [i ]- 1 ] > 0 )
232+ fp = sum (1 for i in neg_index if model [neg_example_covered_var [i ]- 1 ] > 0 )
241233 size = sum ([ruleid_to_size [rule_id ] for rule_id in ruleid_to_size if model [rule_var [rule_id ]- 1 ] > 0 ])
242234
243235 if self .settings .lex :
@@ -305,8 +297,8 @@ def update_best_prog(self, new_progs, timeout=None):
305297
306298 new_solution = reduce_prog (new_solution )
307299 pos_covered , neg_covered = self .tester .test_prog_all (new_solution )
308- tp = len ( pos_covered )
309- fp = len ( neg_covered )
300+ tp = pos_covered . count ( 1 )
301+ fp = neg_covered . count ( 1 )
310302 tn = self .tester .num_neg - fp
311303 fn = self .tester .num_pos - tp
312304 size = calc_prog_size (new_solution )
0 commit comments