Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

package org.junit.platform.engine.support.hierarchical;

import static java.util.Comparator.naturalOrder;
import static java.util.Comparator.reverseOrder;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand All @@ -23,11 +22,13 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Deque;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -333,19 +334,18 @@ void invokeAll(List<? extends TestTask> testTasks) {

List<TestTask> isolatedTasks = new ArrayList<>(testTasks.size());
List<TestTask> sameThreadTasks = new ArrayList<>(testTasks.size());
var reverseQueueEntries = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
var queueEntries = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
executeAll(sameThreadTasks);
var reverseQueueEntriesByResult = tryToStealWorkWithoutBlocking(reverseQueueEntries);
tryToStealWorkWithBlocking(reverseQueueEntriesByResult);
waitFor(reverseQueueEntriesByResult);
var queueEntriesByResult = tryToStealWorkWithoutBlocking(queueEntries);
tryToStealWorkWithBlocking(queueEntriesByResult);
waitFor(queueEntriesByResult);
executeAll(isolatedTasks);
}

private List<WorkQueue.Entry> forkConcurrentChildren(List<? extends TestTask> children,
Consumer<TestTask> isolatedTaskCollector, List<TestTask> sameThreadTasks) {

List<WorkQueue.Entry> queueEntries = new ArrayList<>(children.size());
int index = 0;
for (TestTask child : children) {
if (requiresGlobalReadWriteLock(child)) {
isolatedTaskCollector.accept(child);
Expand All @@ -354,20 +354,19 @@ else if (child.getExecutionMode() == SAME_THREAD) {
sameThreadTasks.add(child);
}
else {
queueEntries.add(workQueue.createEntry(child, index++));
queueEntries.add(new WorkQueue.Entry(child, nextChildIndex()));
}
}

if (!queueEntries.isEmpty()) {
queueEntries.sort(WorkQueue.Entry.CHILD_COMPARATOR);
if (sameThreadTasks.isEmpty()) {
// hold back one task for this thread
var lastEntry = queueEntries.stream().max(naturalOrder()).orElseThrow();
queueEntries.remove(lastEntry);
sameThreadTasks.add(lastEntry.task);
var firstEntry = queueEntries.remove(0);
sameThreadTasks.add(firstEntry.task);
}
forkAll(queueEntries);
}
queueEntries.sort(reverseOrder());
return queueEntries;
}

Expand Down Expand Up @@ -562,9 +561,10 @@ private void tryToStealWorkFromSubmittedChildren() {
if (currentSubmittedChildren == null || currentSubmittedChildren.isEmpty()) {
return;
}
var iterator = currentSubmittedChildren.listIterator(currentSubmittedChildren.size());
while (iterator.hasPrevious()) {
WorkQueue.Entry entry = iterator.previous();
currentSubmittedChildren.sort(WorkQueue.Entry.CHILD_COMPARATOR);
var iterator = currentSubmittedChildren.iterator();
while (iterator.hasNext()) {
WorkQueue.Entry entry = iterator.next();
var result = tryToStealWork(entry, BlockingMode.NON_BLOCKING);
if (result.isExecuted()) {
iterator.remove();
Expand Down Expand Up @@ -653,19 +653,15 @@ private enum BlockingMode {
}

private static class WorkQueue implements Iterable<WorkQueue.Entry> {
private final Set<Entry> queue = new ConcurrentSkipListSet<>();

private final Set<Entry> queue = new ConcurrentSkipListSet<>(Entry.QUEUE_COMPARATOR);

Entry add(TestTask task, int index) {
Entry entry = createEntry(task, index);
Entry entry = new Entry(task, index);
LOGGER.trace(() -> "forking: " + entry.task);
return doAdd(entry);
}

Entry createEntry(TestTask task, int index) {
var uniqueId = task.getTestDescriptor().getUniqueId();
return new Entry(uniqueId, task, new CompletableFuture<>(), index);
}

void addAll(Collection<Entry> entries) {
entries.forEach(this::doAdd);
}
Expand Down Expand Up @@ -696,68 +692,106 @@ public Iterator<Entry> iterator() {
return queue.iterator();
}

private record Entry(UniqueId id, TestTask task, CompletableFuture<@Nullable Void> future, int index)
implements Comparable<Entry> {
private static final class Entry {

private static final Comparator<Entry> QUEUE_COMPARATOR = comparing(Entry::level).reversed() //
.thenComparing(Entry::isContainer) // tests before containers
.thenComparing(Entry::index) //
.thenComparing(Entry::uniqueId, new SameLengthUniqueIdComparator());

private static final Comparator<Entry> CHILD_COMPARATOR = comparing(Entry::isContainer).reversed() // containers before tests
.thenComparing(Entry::index);

private final TestTask task;
private final CompletableFuture<@Nullable Void> future;
private final int index;

@SuppressWarnings("FutureReturnValueIgnored")
Entry {
future.whenComplete((__, t) -> {
Entry(TestTask task, int index) {
this.future = new CompletableFuture<>();
this.future.whenComplete((__, t) -> {
if (t == null) {
LOGGER.trace(() -> "completed normally: " + this.task());
LOGGER.trace(() -> "completed normally: " + task);
}
else {
LOGGER.trace(t, () -> "completed exceptionally: " + this.task());
LOGGER.trace(t, () -> "completed exceptionally: " + task);
}
});
this.task = task;
this.index = index;
}

@Override
public int compareTo(Entry that) {
var result = Integer.compare(that.getLevel(), getLevel());
if (result != 0) {
return result;
}
result = Boolean.compare(this.isContainer(), that.isContainer());
if (result != 0) {
return result;
}
result = Integer.compare(that.index(), this.index());
if (result != 0) {
return result;
}
return compareBy(that.id(), this.id());
private int index() {
return this.index;
}

private int compareBy(UniqueId a, UniqueId b) {
var aIterator = a.getSegments().iterator();
var bIterator = b.getSegments().iterator();
private int level() {
return uniqueId().getSegments().size();
}

// ids have the same length
while (aIterator.hasNext()) {
var aCurrent = aIterator.next();
var bCurrent = bIterator.next();
int result = compareBy(aCurrent, bCurrent);
if (result != 0) {
return result;
}
}
return 0;
private boolean isContainer() {
return task.getTestDescriptor().isContainer();
}

private int compareBy(UniqueId.Segment a, UniqueId.Segment b) {
int result = a.getType().compareTo(b.getType());
if (result != 0) {
return result;
private UniqueId uniqueId() {
return task.getTestDescriptor().getUniqueId();
}

CompletableFuture<@Nullable Void> future() {
return future;
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
return a.getValue().compareTo(b.getValue());
if (obj == null || obj.getClass() != this.getClass()) {
return false;
}
var that = (Entry) obj;
return Objects.equals(this.uniqueId(), that.uniqueId()) && this.index == that.index;
}

private int getLevel() {
return this.id.getSegments().size();
@Override
public int hashCode() {
return Objects.hash(uniqueId(), index);
}

private boolean isContainer() {
return task.getTestDescriptor().isContainer();
@Override
public String toString() {
return new ToStringBuilder(this) //
.append("task", task) //
.append("index", index) //
.toString();
}

private static class SameLengthUniqueIdComparator implements Comparator<UniqueId> {

@Override
public int compare(UniqueId a, UniqueId b) {
var aIterator = a.getSegments().iterator();
var bIterator = b.getSegments().iterator();

// ids have the same length
while (aIterator.hasNext()) {
var aCurrent = aIterator.next();
var bCurrent = bIterator.next();
int result = compareBy(aCurrent, bCurrent);
if (result != 0) {
return result;
}
}
return 0;
}

private static int compareBy(UniqueId.Segment a, UniqueId.Segment b) {
int result = a.getType().compareTo(b.getType());
if (result != 0) {
return result;
}
return a.getValue().compareTo(b.getValue());
}
}

}
Expand Down
Loading