Skip to content

Commit f47b08f

Browse files
perf: optimize memory usage of precompute (#1885)
Also settle the node network in Bavet when the working solution is set to a new instance so time taken by precomputes are not considered by terminations.
1 parent 9eb30e8 commit f47b08f

14 files changed

+170
-84
lines changed

core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/PrecomputeBiNode.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
package ai.timefold.solver.core.impl.bavet.bi;
22

3-
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
3+
import java.util.function.Supplier;
4+
45
import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode;
56
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
6-
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
77
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
8+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
89

910
import org.jspecify.annotations.NullMarked;
1011

1112
@NullMarked
1213
public final class PrecomputeBiNode<A, B> extends AbstractPrecomputeNode<BiTuple<A, B>> {
1314
private final int outputStoreSize;
1415

15-
public PrecomputeBiNode(NodeNetwork nodeNetwork,
16-
RecordingTupleLifecycle<BiTuple<A, B>> recordingTupleNode,
16+
public PrecomputeBiNode(Supplier<BavetPrecomputeBuildHelper<BiTuple<A, B>>> precomputeBuildHelperSupplier,
1717
int outputStoreSize,
1818
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
1919
Class<?>[] sourceClasses) {
20-
super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses);
20+
super(precomputeBuildHelperSupplier, nextNodesTupleLifecycle, sourceClasses);
2121
this.outputStoreSize = outputStoreSize;
2222
}
2323

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package ai.timefold.solver.core.impl.bavet.common;
22

3-
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
3+
import java.util.function.Supplier;
4+
45
import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple;
5-
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
66
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
7+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
78

89
import org.jspecify.annotations.NullMarked;
910
import org.jspecify.annotations.Nullable;
@@ -14,12 +15,10 @@ public abstract class AbstractPrecomputeNode<Tuple_ extends AbstractTuple> exten
1415
private final RecordAndReplayPropagator<Tuple_> recordAndReplayPropagator;
1516
private final Class<?>[] sourceClasses;
1617

17-
protected AbstractPrecomputeNode(NodeNetwork innerNodeNetwork,
18-
RecordingTupleLifecycle<Tuple_> recordingTupleLifecycle,
18+
protected AbstractPrecomputeNode(Supplier<BavetPrecomputeBuildHelper<Tuple_>> precomputeBuildHelperSupplier,
1919
TupleLifecycle<Tuple_> nextNodesTupleLifecycle,
2020
Class<?>[] sourceClasses) {
21-
this.recordAndReplayPropagator = new RecordAndReplayPropagator<>(innerNodeNetwork,
22-
recordingTupleLifecycle,
21+
this.recordAndReplayPropagator = new RecordAndReplayPropagator<>(precomputeBuildHelperSupplier,
2322
this::remapTuple,
2423
nextNodesTupleLifecycle);
2524
this.sourceClasses = sourceClasses;

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java

Lines changed: 80 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
package ai.timefold.solver.core.impl.bavet.common;
22

33
import java.util.ArrayList;
4+
import java.util.Collections;
5+
import java.util.HashMap;
46
import java.util.IdentityHashMap;
57
import java.util.List;
68
import java.util.Map;
79
import java.util.Set;
10+
import java.util.function.Supplier;
811
import java.util.function.UnaryOperator;
912

1013
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
1114
import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple;
1215
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
1316
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
1417
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState;
18+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
1519
import ai.timefold.solver.core.impl.util.CollectionUtils;
1620

1721
import org.jspecify.annotations.NullMarked;
@@ -32,39 +36,42 @@ public final class RecordAndReplayPropagator<Tuple_ extends AbstractTuple>
3236
private final Set<Object> updateQueue;
3337
private final Set<Object> insertQueue;
3438

35-
private final NodeNetwork internalNodeNetwork;
36-
private final RecordingTupleLifecycle<Tuple_> recordingTupleLifecycle;
39+
// Store entities and facts separately; we don't need to precompute
40+
// the tuples for facts, since facts never update
41+
private final Set<Object> seenEntitySet;
42+
private final Set<Object> seenFactSet;
43+
44+
private final Supplier<BavetPrecomputeBuildHelper<Tuple_>> precomputeBuildHelperSupplier;
3745
private final UnaryOperator<Tuple_> internalTupleToOutputTupleMapper;
38-
private final Map<Tuple_, Tuple_> internalTupleToOutputTupleMap;
3946
private final Map<Object, List<Tuple_>> objectToOutputTuplesMap;
47+
private final Map<Class<?>, Boolean> objectClassToIsEntitySourceClass;
4048

4149
private final StaticPropagationQueue<Tuple_> propagationQueue;
4250

4351
public RecordAndReplayPropagator(
44-
NodeNetwork internalNodeNetwork,
45-
RecordingTupleLifecycle<Tuple_> recordingTupleLifecycle,
52+
Supplier<BavetPrecomputeBuildHelper<Tuple_>> precomputeBuildHelperSupplier,
4653
UnaryOperator<Tuple_> internalTupleToOutputTupleMapper,
4754
TupleLifecycle<Tuple_> nextNodesTupleLifecycle, int size) {
48-
this.internalNodeNetwork = internalNodeNetwork;
49-
this.recordingTupleLifecycle = recordingTupleLifecycle;
55+
this.precomputeBuildHelperSupplier = precomputeBuildHelperSupplier;
5056
this.internalTupleToOutputTupleMapper = internalTupleToOutputTupleMapper;
51-
this.internalTupleToOutputTupleMap = CollectionUtils.newIdentityHashMap(size);
5257
this.objectToOutputTuplesMap = CollectionUtils.newIdentityHashMap(size);
5358

5459
// Guesstimate that updates are dominant.
5560
this.retractQueue = CollectionUtils.newIdentityHashSet(size / 20);
5661
this.updateQueue = CollectionUtils.newIdentityHashSet((size / 20) * 18);
5762
this.insertQueue = CollectionUtils.newIdentityHashSet(size / 20);
63+
this.objectClassToIsEntitySourceClass = new HashMap<>();
64+
this.seenEntitySet = CollectionUtils.newIdentityHashSet(size);
65+
this.seenFactSet = CollectionUtils.newIdentityHashSet(size);
5866

5967
this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle);
6068
}
6169

6270
public RecordAndReplayPropagator(
63-
NodeNetwork internalNodeNetwork,
64-
RecordingTupleLifecycle<Tuple_> recordingTupleLifecycle,
71+
Supplier<BavetPrecomputeBuildHelper<Tuple_>> precomputeBuildHelperSupplier,
6572
UnaryOperator<Tuple_> internalTupleToOutputTupleMapper,
6673
TupleLifecycle<Tuple_> nextNodesTupleLifecycle) {
67-
this(internalNodeNetwork, recordingTupleLifecycle, internalTupleToOutputTupleMapper, nextNodesTupleLifecycle, 1000);
74+
this(precomputeBuildHelperSupplier, internalTupleToOutputTupleMapper, nextNodesTupleLifecycle, 1000);
6875
}
6976

7077
public void insert(Object object) {
@@ -87,33 +94,72 @@ public void retract(Object object) {
8794
@Override
8895
public void propagateRetracts() {
8996
if (!retractQueue.isEmpty() || !insertQueue.isEmpty()) {
90-
updateQueue.removeAll(retractQueue);
91-
updateQueue.removeAll(insertQueue);
97+
var precomputeBuildHelper = precomputeBuildHelperSupplier.get();
98+
var internalNodeNetwork = precomputeBuildHelper.getNodeNetwork();
99+
var objectClassToRootNodes = new HashMap<Class<?>, List<BavetRootNode<?>>>();
100+
var recordingTupleLifecycle = precomputeBuildHelper.getRecordingTupleLifecycle();
101+
102+
invalidateCache();
103+
seenEntitySet.removeAll(retractQueue);
104+
seenFactSet.removeAll(retractQueue);
105+
106+
for (var entity : seenEntitySet) {
107+
for (var rootNode : getRootNodes(entity, internalNodeNetwork, objectClassToRootNodes)) {
108+
rootNode.insert(entity);
109+
}
110+
}
111+
112+
for (var fact : seenFactSet) {
113+
for (var rootNode : getRootNodes(fact, internalNodeNetwork, objectClassToRootNodes)) {
114+
rootNode.insert(fact);
115+
}
116+
}
117+
92118
// Do not remove queued retracts from inserts; if a fact property
93119
// change, there will be both a retract and insert for that fact
94-
invalidateCache();
120+
for (var object : insertQueue) {
121+
if (objectClassToIsEntitySourceClass.computeIfAbsent(object.getClass(),
122+
precomputeBuildHelper::isSourceEntityClass)) {
123+
seenEntitySet.add(object);
124+
} else {
125+
seenFactSet.add(object);
126+
}
127+
for (var rootNode : getRootNodes(object, internalNodeNetwork, objectClassToRootNodes)) {
128+
rootNode.insert(object);
129+
}
130+
}
95131

96-
retractQueue.forEach(this::retractFromInternalNodeNetwork);
97-
insertQueue.forEach(this::insertIntoInternalNodeNetwork);
132+
updateQueue.clear();
98133
retractQueue.clear();
99134
insertQueue.clear();
100135

101136
// settle the inner node network, so the inserts/retracts do not interfere
102137
// with the recording of the first object's tuples
103138
internalNodeNetwork.settle();
104-
recalculateTuples();
139+
recalculateTuples(internalNodeNetwork, objectClassToRootNodes, recordingTupleLifecycle);
140+
105141
propagationQueue.propagateRetracts();
106142
}
107143
}
108144

145+
@SuppressWarnings({ "unchecked", "rawtypes" })
146+
private static <A> List<BavetRootNode<A>> getRootNodes(Object object, NodeNetwork internalNodeNetwork,
147+
Map<Class<?>, List<BavetRootNode<?>>> objectClassToRootNodes) {
148+
return (List) objectClassToRootNodes.computeIfAbsent(object.getClass(), clazz -> {
149+
var out = new ArrayList<BavetRootNode<?>>();
150+
internalNodeNetwork.getRootNodesAcceptingType(object.getClass()).forEach(out::add);
151+
return out;
152+
});
153+
}
154+
109155
@Override
110156
public void propagateUpdates() {
157+
Set<Tuple_> updatedTuples = CollectionUtils.newIdentityHashSet(2 * updateQueue.size());
111158
for (var update : updateQueue) {
112-
for (var updatedTuple : objectToOutputTuplesMap.get(update)) {
113-
propagationQueue.update(updatedTuple);
114-
}
159+
updatedTuples.addAll(objectToOutputTuplesMap.get(update));
115160
}
116161
updateQueue.clear();
162+
updatedTuples.forEach(propagationQueue::update);
117163
propagationQueue.propagateUpdates();
118164
}
119165

@@ -144,36 +190,30 @@ private void retractIfPresent(Tuple_ tuple) {
144190
}
145191
}
146192

147-
private void insertIntoInternalNodeNetwork(Object toInsert) {
148-
objectToOutputTuplesMap.put(toInsert, new ArrayList<>());
149-
internalNodeNetwork.getRootNodesAcceptingType(toInsert.getClass())
150-
.forEach(node -> ((BavetRootNode<Object>) node).insert(toInsert));
151-
}
152-
153-
private void retractFromInternalNodeNetwork(Object toRetract) {
154-
objectToOutputTuplesMap.remove(toRetract);
155-
internalNodeNetwork.getRootNodesAcceptingType(toRetract.getClass())
156-
.forEach(node -> ((BavetRootNode<Object>) node).retract(toRetract));
157-
}
158-
159193
private void invalidateCache() {
160194
objectToOutputTuplesMap.values().stream().flatMap(List::stream).forEach(this::retractIfPresent);
161-
internalTupleToOutputTupleMap.clear();
195+
objectToOutputTuplesMap.clear();
162196
}
163197

164-
private void recalculateTuples() {
165-
for (var mappedTupleEntry : objectToOutputTuplesMap.entrySet()) {
166-
mappedTupleEntry.getValue().clear();
167-
var invalidated = mappedTupleEntry.getKey();
198+
private void recalculateTuples(NodeNetwork internalNodeNetwork,
199+
Map<Class<?>, List<BavetRootNode<?>>> classToRootNodeList,
200+
RecordingTupleLifecycle<Tuple_> recordingTupleLifecycle) {
201+
var internalTupleToOutputTupleMap = new IdentityHashMap<Tuple_, Tuple_>(seenEntitySet.size());
202+
for (var invalidated : seenEntitySet) {
203+
var mappedTuples = new ArrayList<Tuple_>();
168204
try (var unusedActiveRecordingLifecycle = recordingTupleLifecycle.recordInto(
169-
new TupleRecorder<>(mappedTupleEntry.getValue(), internalTupleToOutputTupleMapper,
170-
(IdentityHashMap<Tuple_, Tuple_>) internalTupleToOutputTupleMap))) {
205+
new TupleRecorder<>(mappedTuples, internalTupleToOutputTupleMapper, internalTupleToOutputTupleMap))) {
171206
// Do a fake update on the object and settle the network; this will update precisely the
172207
// tuples mapped to this node, which will then be recorded
173-
internalNodeNetwork.getRootNodesAcceptingType(invalidated.getClass())
208+
classToRootNodeList.get(invalidated.getClass())
174209
.forEach(node -> ((BavetRootNode<Object>) node).update(invalidated));
175210
internalNodeNetwork.settle();
176211
}
212+
if (mappedTuples.isEmpty()) {
213+
objectToOutputTuplesMap.put(invalidated, Collections.emptyList());
214+
} else {
215+
objectToOutputTuplesMap.put(invalidated, mappedTuples);
216+
}
177217
}
178218
objectToOutputTuplesMap.values().stream().flatMap(List::stream).forEach(this::insertIfAbsent);
179219
}

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ public void close() {
2323
@Override
2424
public void insert(Tuple_ tuple) {
2525
if (tupleRecorder != null) {
26-
throw new IllegalStateException("Impossible state: tuple %s was inserted during recording".formatted(tuple));
26+
throw new IllegalStateException("""
27+
Illegal state: tuple %s was inserted during recording.
28+
Certain operations like flattenLast will create new tuples
29+
on update if its mapping function returns a new instance.
30+
Maybe refactor the code to avoid creating new instances,
31+
or avoid using precompute if that is not possible.
32+
""".formatted(tuple));
2733
}
2834
}
2935

core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/PrecomputeQuadNode.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
package ai.timefold.solver.core.impl.bavet.quad;
22

3-
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
3+
import java.util.function.Supplier;
4+
45
import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode;
56
import ai.timefold.solver.core.impl.bavet.common.tuple.QuadTuple;
6-
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
77
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
8+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
89

910
import org.jspecify.annotations.NullMarked;
1011

1112
@NullMarked
1213
public final class PrecomputeQuadNode<A, B, C, D> extends AbstractPrecomputeNode<QuadTuple<A, B, C, D>> {
1314
private final int outputStoreSize;
1415

15-
public PrecomputeQuadNode(NodeNetwork nodeNetwork,
16-
RecordingTupleLifecycle<QuadTuple<A, B, C, D>> recordingTupleNode,
16+
public PrecomputeQuadNode(Supplier<BavetPrecomputeBuildHelper<QuadTuple<A, B, C, D>>> precomputeBuildHelperSupplier,
1717
int outputStoreSize,
1818
TupleLifecycle<QuadTuple<A, B, C, D>> nextNodesTupleLifecycle,
1919
Class<?>[] sourceClasses) {
20-
super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses);
20+
super(precomputeBuildHelperSupplier, nextNodesTupleLifecycle, sourceClasses);
2121
this.outputStoreSize = outputStoreSize;
2222
}
2323

core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/PrecomputeTriNode.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
package ai.timefold.solver.core.impl.bavet.tri;
22

3-
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
3+
import java.util.function.Supplier;
4+
45
import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode;
5-
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
66
import ai.timefold.solver.core.impl.bavet.common.tuple.TriTuple;
77
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
8+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
89

910
import org.jspecify.annotations.NullMarked;
1011

1112
@NullMarked
1213
public final class PrecomputeTriNode<A, B, C> extends AbstractPrecomputeNode<TriTuple<A, B, C>> {
1314
private final int outputStoreSize;
1415

15-
public PrecomputeTriNode(NodeNetwork nodeNetwork,
16-
RecordingTupleLifecycle<TriTuple<A, B, C>> recordingTupleNode,
16+
public PrecomputeTriNode(Supplier<BavetPrecomputeBuildHelper<TriTuple<A, B, C>>> precomputeBuildHelperSupplier,
1717
int outputStoreSize,
1818
TupleLifecycle<TriTuple<A, B, C>> nextNodesTupleLifecycle,
1919
Class<?>[] sourceClasses) {
20-
super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses);
20+
super(precomputeBuildHelperSupplier, nextNodesTupleLifecycle, sourceClasses);
2121
this.outputStoreSize = outputStoreSize;
2222
}
2323

core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
package ai.timefold.solver.core.impl.bavet.uni;
22

3-
import ai.timefold.solver.core.impl.bavet.NodeNetwork;
3+
import java.util.function.Supplier;
4+
45
import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode;
5-
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
66
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
77
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;
8+
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
89

910
import org.jspecify.annotations.NullMarked;
1011

1112
@NullMarked
1213
public final class PrecomputeUniNode<A> extends AbstractPrecomputeNode<UniTuple<A>> {
1314
private final int outputStoreSize;
1415

15-
public PrecomputeUniNode(NodeNetwork nodeNetwork,
16-
RecordingTupleLifecycle<UniTuple<A>> recordingTupleNode,
16+
public PrecomputeUniNode(Supplier<BavetPrecomputeBuildHelper<UniTuple<A>>> precomputeBuildHelperSupplier,
1717
int outputStoreSize,
1818
TupleLifecycle<UniTuple<A>> nextNodesTupleLifecycle,
1919
Class<?>[] sourceClasses) {
20-
super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses);
20+
super(precomputeBuildHelperSupplier, nextNodesTupleLifecycle, sourceClasses);
2121
this.outputStoreSize = outputStoreSize;
2222
}
2323

core/src/main/java/ai/timefold/solver/core/impl/score/director/AbstractScoreDirector.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,15 @@ public final void setWorkingSolution(Solution_ workingSolution) {
282282
setWorkingSolutionWithoutUpdatingShadows(workingSolution);
283283
forceTriggerVariableListeners();
284284
expectShadowVariablesInCorrectState = originalShouldAssert;
285+
afterSetWorkingSolution();
286+
}
287+
288+
/**
289+
* Note: by default does nothing. Subclasses should override this if they
290+
* need to compute something after shadow variables are set.
291+
*/
292+
protected void afterSetWorkingSolution() {
293+
// Do nothing
285294
}
286295

287296
@Override

core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ public void setWorkingSolutionWithoutUpdatingShadows(Solution_ workingSolution)
7878
super.setWorkingSolutionWithoutUpdatingShadows(workingSolution, session::insert);
7979
}
8080

81+
@Override
82+
protected void afterSetWorkingSolution() {
83+
// Settle the node network to calculate precomputes
84+
// This is required so precomputes are not considered by terminations
85+
session.settle();
86+
}
87+
8188
@Override
8289
public InnerScore<Score_> calculateScore() {
8390
variableListenerSupport.assertNotificationQueuesAreEmpty();

0 commit comments

Comments
 (0)