diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java index 255cca70f5..0c4f74231c 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java @@ -74,6 +74,15 @@ public static String generateTaskId(final String stageId, final int index, final return stageId + SPLITTER + index + SPLITTER + attempt; } + /** + * Generates the ID of a task created by Work Stealing. + * @param taskId the ID of the original task. + * @return the generated ID. + */ + public static String generateWorkStealingTaskId(final String taskId) { + return getStageIdFromTaskId(taskId) + SPLITTER + getIndexFromTaskId(taskId) + SPLITTER + "*"; + } + /** * Generates the ID for executor. *
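A note on the ID convention introduced here: a work stealing task reuses the stage ID and index of its original task and replaces the attempt field with the wildcard "*". A hypothetical usage sketch (assuming SPLITTER is "-", consistent with the taskId.split("-") check added to PlanStateManager later in this diff):

    final String original = RuntimeIdManager.generateTaskId("Stage2", 3, 0);      // "Stage2-3-0"
    final String stealer = RuntimeIdManager.generateWorkStealingTaskId(original); // "Stage2-3-*"
    // The stage ID ("Stage2") and index (3) still parse normally, but the attempt field
    // is no longer numeric, so getAttemptFromTaskId must not be called on stealer IDs.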
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java index 531e715a7b..7d98140cc3 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric { private long shuffleReadTime = -1; private long shuffleWriteBytes = -1; private long shuffleWriteTime = -1; + private int currentIteratorIndex = -1; + private int totalIteratorNumber = -1; + private long taskPreparationTime = -1; private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName()); @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) { this.shuffleWriteTime = shuffleWriteTime; } + public final int getCurrentIteratorIndex() { + return this.currentIteratorIndex; + } + + private void setCurrentIteratorIndex(final int currentIteratorIndex) { + this.currentIteratorIndex = currentIteratorIndex; + } + + public final int getTotalIteratorNumber() { + return this.totalIteratorNumber; + } + + private void setTotalIteratorNumber(final int totalIteratorNumber) { + this.totalIteratorNumber = totalIteratorNumber; + } + + public final long getTaskPreparationTime() { + return this.taskPreparationTime; + } + + private void setTaskPreparationTime(final long taskPreparationTime) { + this.taskPreparationTime = taskPreparationTime; + } + @Override public final String getId() { return id; } @@ -317,6 +344,14 @@ public final boolean processMetricMessage(final String metricField, final byte[] case "shuffleWriteTime": setShuffleWriteTime(SerializationUtils.deserialize(metricValue)); break; + case "currentIteratorIndex": + setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue)); + break; + case "totalIteratorNumber": + setTotalIteratorNumber(SerializationUtils.deserialize(metricValue)); + break; + case "taskPreparationTime": + setTaskPreparationTime(SerializationUtils.deserialize(metricValue)); + break; default: LOG.warn("metricField {} is not supported.", metricField); return false; diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java index a7f472c0da..e60cf1648f 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java @@ -33,15 +33,14 @@ import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty; import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; /** * Stage. */ public final class Stage extends Vertex { private final List taskIndices; + private final Set workStealingTaskIds = new HashSet<>(); private final DAG> irDag; private final byte[] serializedIRDag; private final List> vertexIdToReadables; @@ -93,6 +92,18 @@ public List getTaskIndices() { return taskIndices; } + /** + * Registers the IDs of tasks created by work stealing. + * @param workStealingTaskIds IDs of work stealer tasks. + */ + public void setWorkStealingTaskIds(final Set workStealingTaskIds) { + this.workStealingTaskIds.addAll(workStealingTaskIds); + } + + public Set getWorkStealingTaskIds() { + return this.workStealingTaskIds; + } + /** * @return the parallelism. */ diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java index 719075b456..5d827e4d07 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; /** * A Task (attempt) is a self-contained executable that can be executed on a machine. @@ -40,8 +41,13 @@ public final class Task implements Serializable { private final byte[] serializedIRDag; private final Map irVertexIdToReadable; + /* For work stealing */ + private final AtomicInteger iteratorStartIndex; + private final AtomicInteger iteratorEndIndex; + /** - * Constructor. + * Default constructor. + * It initializes iteratorStartIndex to 0 and iteratorEndIndex to Integer.MAX_VALUE. * * @param planId the id of the physical plan. * @param taskId the ID of this task attempt. * @@ -58,6 +64,33 @@ public Task(final String planId, final List taskIncomingEdges, final List taskOutgoingEdges, final Map irVertexIdToReadable) { + this(planId, taskId, executionProperties, serializedIRDag, taskIncomingEdges, taskOutgoingEdges, + irVertexIdToReadable, new AtomicInteger(0), new AtomicInteger(Integer.MAX_VALUE)); + } + + /** + * Constructor with iterator information. + * This constructor is used when creating work stealer tasks. + * + * @param planId the id of the physical plan. + * @param taskId the ID of this task attempt. + * @param executionProperties {@link VertexExecutionProperty} map for the corresponding stage + * @param serializedIRDag the serialized DAG of the task. + * @param taskIncomingEdges the incoming edges of the task. + * @param taskOutgoingEdges the outgoing edges of the task. + * @param irVertexIdToReadable the map between IRVertex id to readable. + * @param iteratorStartIndex starting index of the iterator (inclusive). + * @param iteratorEndIndex ending index of the iterator (exclusive).
+ */ + public Task(final String planId, + final String taskId, + final ExecutionPropertyMap executionProperties, + final byte[] serializedIRDag, + final List taskIncomingEdges, + final List taskOutgoingEdges, + final Map irVertexIdToReadable, + final AtomicInteger iteratorStartIndex, + final AtomicInteger iteratorEndIndex) { this.planId = planId; this.taskId = taskId; this.executionProperties = executionProperties; @@ -65,6 +98,8 @@ public Task(final String planId, this.taskIncomingEdges = taskIncomingEdges; this.taskOutgoingEdges = taskOutgoingEdges; this.irVertexIdToReadable = irVertexIdToReadable; + this.iteratorStartIndex = iteratorStartIndex; + this.iteratorEndIndex = iteratorEndIndex; } /** diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 97e30fb4e7..bd5ed2d30f 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -86,6 +86,9 @@ enum MessageType { PipeInit = 13; RequestPipeLoc = 14; PipeLocInfo = 15; + ParentTaskDataCollected = 16; + CurrentlyProcessedBytesCollected = 17; + SendWorkStealingResult = 18; } message Message { @@ -107,6 +110,9 @@ message Message { optional PipeInitMessage pipeInitMsg = 16; optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; + optional ParentTaskDataCollectMsg parentTaskDataCollected = 19; + optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; + optional WorkStealingResultMsg sendWorkStealingResult = 21; } // Messages from Master to Executors @@ -256,3 +262,17 @@ message PipeLocationInfoMessage { required int64 requestId = 1; // To find the matching request msg required string executorId = 2; } + +message ParentTaskDataCollectMsg { + required string taskId = 1; + required bytes partitionSizeMap = 2; +} + +message CurrentlyProcessedBytesCollectMsg { + required string taskId = 1; + required int64 processedDataBytes = 2; +} + +message WorkStealingResultMsg { + required bytes workStealingResult = 1; +}
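The executor-side handler for SendWorkStealingResult is not part of this diff. As a minimal sketch (method name and placement hypothetical), a receiver would mirror the serialization performed by BatchScheduler#sendWorkStealingResultToExecutor further below, which sends a Map of task ID to (iteratorStartIndex, iteratorEndIndex) pairs:

    @SuppressWarnings("unchecked")
    private void onWorkStealingResult(final ControlMessage.Message message) {
      final Map<String, Pair<Integer, Integer>> ranges = (Map<String, Pair<Integer, Integer>>)
          SerializationUtils.deserialize(
              message.getSendWorkStealingResult().getWorkStealingResult().toByteArray());
      // Each running task would then shrink its iterator range to [pair.left(), pair.right()).
    }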
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java index 97cde037c6..b9cf9ef129 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter { private long writtenBytes; + private Optional> partitionSizeMap = Optional.empty(); + /** * Constructor. * @@ -109,7 +111,7 @@ public void close() { final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge .getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new); - final Optional> partitionSizeMap = blockToWrite.commit(); + partitionSizeMap = blockToWrite.commit(); // Return the total size of the committed block. if (partitionSizeMap.isPresent()) { long blockSizeTotal = 0; @@ -123,6 +125,16 @@ public void close() { blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence); } + @Override + public Optional> getPartitionSizeMap() { + // Initialized to Optional.empty(), and populated only after close() commits the block. + return partitionSizeMap; + } + @Override public Optional getWrittenBytes() { if (writtenBytes == -1) { return Optional.empty(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java index bf6ff84e69..a1862f5f2d 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.punctuation.Watermark; +import java.util.Map; import java.util.Optional; /** @@ -45,5 +46,10 @@ public interface OutputWriter { */ Optional getWrittenBytes(); + /** + * @return the map of hashed key to partition size. + */ + Optional> getPartitionSizeMap(); + void close(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java index 544d64d921..d0025428aa 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -113,6 +114,11 @@ public Optional getWrittenBytes() { return Optional.empty(); } + @Override + public Optional> getPartitionSizeMap() { + return Optional.empty(); + } + @Override public void close() { if (!initialized) { diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java index 7af08852eb..b1a828c13c 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.io.IOException; @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable { */ abstract Object fetchDataElement() throws IOException; + /** + * Identical to fetchDataElement(), except that it sends intermediate serializedReadBytes to MetricStore + * on every iterator advance. + * This method is for the work stealing implementation in Nemo.
+ * + * @param taskId task id + * @param metricMessageSender metricMessageSender + * + * @return data element + * @throws IOException upon I/O error + * @throws java.util.NoSuchElementException if no more element is available + */ + abstract Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException; + OutputCollector getOutputCollector() { return outputCollector; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index 797818ce44..d7947e8c78 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -22,6 +22,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.*; import org.slf4j.Logger; @@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException { } } + @Override + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { + return fetchDataElement(); + } + private void fetchDataLazily() { final List> futures = readersForParentTask.read(); numOfIterators = futures.size(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index a8ae4a9306..3a92cbc8a9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -18,10 +18,12 @@ */ package org.apache.nemo.runtime.executor.task; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.InputReader; import org.slf4j.Logger; @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException { return Finishmark.getInstance(); } + @Override + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { + try { + if (firstFetch) { + fetchDataLazily(); + advanceIterator(); + firstFetch = false; + } + + while (true) { + // This iterator has the element + if (this.currentIterator.hasNext()) { + return this.currentIterator.next(); + } + + // This iterator does not have the element + if (currentIteratorIndex < expectedNumOfIterators) { + // Next iterator has the element + countBytes(currentIterator); + // Send the cumulative serBytes to MetricStore + metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes", + SerializationUtils.serialize(serBytes)); + advanceIterator(); + continue; + } else { + // We've consumed all the iterators + 
break; + } + } + } catch (final Throwable e) { + // Any failure is caught and thrown as an IOException, so that the task is retried. + // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes + // when remote data fetching fails for whatever reason. + // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard + // "throw Exception" that the TaskExecutor thread can catch and handle. + throw new IOException(e); + } + + return Finishmark.getInstance(); + } + private void advanceIterator() throws IOException { // Take from iteratorQueue final Object iteratorOrThrowable; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 2d82898d7a..68a3362d27 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -23,6 +23,7 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -74,6 +75,11 @@ Object fetchDataElement() { } } + @Override + Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) { + return fetchDataElement(); + }  + final long getBoundedSourceReadTime() { return boundedSourceReadTime; }
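TaskMetric gains currentIteratorIndex and totalIteratorNumber above, but the code that emits them is not shown in this diff. A sketch of what ParentTaskDataFetcher#advanceIterator() could additionally send, following the serializedReadBytes pattern of fetchDataElementWithTrace() (taskId and metricMessageSender are assumed to be passed in the same way):

    metricMessageSender.send("TaskMetric", taskId, "currentIteratorIndex",
        SerializationUtils.serialize(currentIteratorIndex));
    metricMessageSender.send("TaskMetric", taskId, "totalIteratorNumber",
        SerializationUtils.serialize(expectedNumOfIterators));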
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 2bf574d396..91e8212640 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -19,6 +19,7 @@ package org.apache.nemo.runtime.executor.task; import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; import org.apache.commons.lang3.SerializationUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.nemo.common.Pair; @@ -458,7 +459,7 @@ private boolean handleDataFetchers(final List fetchers) { while (availableIterator.hasNext()) { final DataFetcher dataFetcher = availableIterator.next(); try { - final Object element = dataFetcher.fetchDataElement(); + final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender); onEventFromDataFetcher(element, dataFetcher); if (element instanceof Finishmark) { availableIterator.remove(); @@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) { */ private void finalizeOutputWriters(final VertexHarness vertexHarness) { final List writtenBytesList = new ArrayList<>(); + final HashMap partitionSizeMap = new HashMap<>(); // finalize OutputWriters for main children vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Accumulate the partition sizes reported by this writer + outputWriter.getPartitionSizeMap() + .ifPresent(partitionSizes -> computePartitionSizeMap(partitionSizeMap, partitionSizes)); }); // finalize OutputWriters for additional tagged children @@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Accumulate the partition sizes reported by this writer + outputWriter.getPartitionSizeMap() + .ifPresent(partitionSizes -> computePartitionSizeMap(partitionSizeMap, partitionSizes)); }) ); @@ -713,5 +731,57 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { // TODO #236: Decouple metric collection and sending logic metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes", SerializationUtils.serialize(totalWrittenBytes)); + + if (!partitionSizeMap.isEmpty()) { + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.ParentTaskDataCollected) + .setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder() + .setTaskId(taskId) + .setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap))) + .build()) + .build()); + } + } + + // Methods for work stealing + /** + * Gather the KV statistics of processed data when execution is completed. + * This method is for the work stealing implementation: the accumulated statistics will be used to + * detect skewed tasks of the child stage. + * + * @param totalPartitionSizeMap accumulated partitionSizeMap of the task. + * @param singlePartitionSizeMap partitionSizeMap obtained from a single OutputWriter. + */ + private void computePartitionSizeMap(final Map totalPartitionSizeMap, + final Map singlePartitionSizeMap) { + singlePartitionSizeMap.forEach((hashedKey, partitionSize) -> + totalPartitionSizeMap.merge(hashedKey, partitionSize, Long::sum)); + } + + /** + * Reports the number of bytes this task has processed so far, on request from the scheduler. + * This method is for the work stealing implementation. + */ + public void onRequestForProcessedData() { + LOG.info("{} replying to the processed-bytes request with {} bytes", taskId, serializedReadBytes); + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected) + .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder() + .setTaskId(this.taskId) + .setProcessedDataBytes(serializedReadBytes) + .build()) + .build()); } }
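A small worked example of the accumulation performed by computePartitionSizeMap() above, where keys are hash partition indices and values are bytes:

    final Map<Integer, Long> total = new HashMap<>();
    total.put(0, 100L);
    total.put(1, 50L);
    final Map<Integer, Long> fromOneWriter = new HashMap<>();
    fromOneWriter.put(1, 25L);
    fromOneWriter.put(2, 10L);
    fromOneWriter.forEach((k, v) -> total.merge(k, v, Long::sum));
    // total is now {0=100, 1=75, 2=10}; this map is what finalizeOutputWriters()
    // serializes into the ParentTaskDataCollectMsg sent to the master.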
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 53cab57810..b8f8f7b335 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -85,6 +85,10 @@ public final class PlanStateManager { private final Map> stageIdToCompletedTaskTimeMsList = new HashMap<>(); private final Map> stageIdToTaskIndexToNumOfClones = new HashMap<>(); + /** + * Used for work stealing. + */ + private final Map>> stageIdToTaskIdxToWSAttemptStates = new HashMap<>(); /** * Represents the plan to manage. */ @@ -127,7 +131,7 @@ public static PlanStateManager newInstance(final String dagDirectory) { } /** - * @param metricStore set the metric store of the paln state manager. + * @param metricStore set the metric store of the plan state manager. */ public void setMetricStore(final MetricStore metricStore) { this.metricStore = metricStore; @@ -156,6 +160,21 @@ public synchronized void updatePlan(final PhysicalPlan physicalPlanToUpdate, initializeStates(); } + /** + * Add work stealing tasks to the plan. + * @param workStealingTasks work stealing tasks. + */ + public synchronized void addWorkStealingTasks(final Set workStealingTasks) { + for (String taskId : workStealingTasks) { + final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); + final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId); + stageIdToTaskIdxToWSAttemptStates.putIfAbsent(stageId, new HashMap<>()); + List attemptStatesForThisTask = new ArrayList<>(); + attemptStatesForThisTask.add(new TaskState()); + stageIdToTaskIdxToWSAttemptStates.get(stageId).putIfAbsent(taskIdx, attemptStatesForThisTask); + } + } + /** * Initializes the states for the plan/stages/tasks for this plan. * TODO #182: Consider reshaping in run-time optimization. For now, we only consider plan appending.
@@ -326,16 +345,11 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); - final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream() - .filter(attempts -> { - final List states = attempts - .stream() - .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) - .collect(Collectors.toList()); - return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) // one of them is ON_HOLD - || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); // one of them is COMPLETE - }) - .count(); + final Map> wsTaskStatesOfThisStage = + stageIdToTaskIdxToWSAttemptStates.getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage) + + getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, taskStatesOfThisStage.size()); @@ -364,9 +378,18 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // COMPLETE stage case COMPLETE: case ON_HOLD: - if (numOfCompletedTaskIndicesInThisStage - == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) { - onStageStateChanged(stageId, StageState.State.COMPLETE); + // A stage is complete when all of its original task indices and all of its work stealing + // tasks are done; the work stealing task set is empty when work stealing was not triggered. + if (numOfCompletedTaskIndicesInThisStage + == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size() + + physicalPlan.getStageDAG().getVertexById(stageId).getWorkStealingTaskIds().size()) { + onStageStateChanged(stageId, StageState.State.COMPLETE); } break; @@ -550,10 +573,18 @@ private Map getTaskAttemptIdsToItsState(final String st } private TaskState getTaskStateHelper(final String taskId) { - return stageIdToTaskIdxToAttemptStates - .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) - .get(RuntimeIdManager.getIndexFromTaskId(taskId)) - .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); + // Work stealing tasks carry the wildcard "*" in the attempt field (see + // RuntimeIdManager#generateWorkStealingTaskId), which cannot be parsed as an attempt number; + // they have exactly one attempt, registered by addWorkStealingTasks(). + final boolean isWorkStealingTask = taskId.split("-")[2].equals("*"); + if (isWorkStealingTask) { + return stageIdToTaskIdxToWSAttemptStates + .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) + .get(RuntimeIdManager.getIndexFromTaskId(taskId)) + .get(0); + } else { + return stageIdToTaskIdxToAttemptStates + .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) + .get(RuntimeIdManager.getIndexFromTaskId(taskId)) + .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); + } } private boolean isTaskNotDone(final TaskState taskState) {
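The completion rule shared by both attempt maps (see getNumberOfCompletedTasksInStage in the next hunk) counts a task index as done when any of its attempts is COMPLETE or ON_HOLD. For example, an index with a still-EXECUTING attempt and a COMPLETE attempt (such as a finished speculative clone) counts as done:

    final List<TaskState.State> states = Arrays.asList(TaskState.State.EXECUTING, TaskState.State.COMPLETE);
    final boolean done = states.stream().anyMatch(s ->
        s.equals(TaskState.State.COMPLETE) || s.equals(TaskState.State.ON_HOLD)); // true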
@@ -577,6 +608,59 @@ private List getPeerAttemptsForTheSameTaskIndex(final String ta .collect(Collectors.toList()); } + /** + * Get the number of remaining tasks of the stage. + * + * @param stageId stage id. + * @return number of remaining tasks. + */ + public int getNumberOfTasksRemainingInStage(final String stageId) { + final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map> wsTaskStatesOfThisStage = stageIdToTaskIdxToWSAttemptStates + .getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndices = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + // The work stealing map contributes nothing when it is empty, so both cases collapse into one. + final long numOfCompletedWorkStealingTaskIndices = getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices + + wsTaskStatesOfThisStage.size() - numOfCompletedWorkStealingTaskIndices); + } + + /** + * Get the tasks which are currently being executed. + * + * @param stageId stage id. + * @return Set of task IDs in execution. + */ + public Set getOngoingTaskIdsInStage(final String stageId) { + final Map> taskIdxToState = stageIdToTaskIdxToAttemptStates.get(stageId); + final Set onGoingTaskIds = new HashSet<>(); + for (final int taskIndex : taskIdxToState.keySet()) { + final List attemptStates = taskIdxToState.get(taskIndex); + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + if (attemptStates.get(attempt).getStateMachine().getCurrentState().equals(TaskState.State.EXECUTING)) { + onGoingTaskIds.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + } + } + } + return onGoingTaskIds; + } + + private long getNumberOfCompletedTasksInStage(final Map> taskIdxToState) { + return taskIdxToState.values().stream() + .filter(attempts -> { + final List states = attempts + .stream() + .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) + .collect(Collectors.toList()); + return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) + || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); + }) + .count(); + } + /** * @return the physical plan.
*/ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..44eb0f7a78 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -58,10 +58,7 @@ import javax.inject.Inject; import java.io.Serializable; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -85,9 +82,11 @@ public final class RuntimeMaster { private static final int METRIC_ARRIVE_TIMEOUT = 10000; private static final int REST_SERVER_PORT = 10101; private static final int SPECULATION_CHECKING_PERIOD_MS = 100; + private static final int WORK_STEALING_CHECKING_PERIOD_MS = 100; private final ExecutorService runtimeMasterThread; private final ScheduledExecutorService speculativeTaskCloningThread; + private final ScheduledExecutorService workStealingThread; private final Scheduler scheduler; private final ContainerManager containerManager; @@ -160,6 +159,16 @@ private RuntimeMaster(final Scheduler scheduler, SPECULATION_CHECKING_PERIOD_MS, TimeUnit.MILLISECONDS); + // Check for work stealing every WORK_STEALING_CHECKING_PERIOD_MS milliseconds + this.workStealingThread = Executors + .newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "WorkStealing master thread")); + this.workStealingThread.scheduleWithFixedDelay( + () -> this.runtimeMasterThread.submit(scheduler::onWorkStealingCheck), + WORK_STEALING_CHECKING_PERIOD_MS, + WORK_STEALING_CHECKING_PERIOD_MS, + TimeUnit.MILLISECONDS); + this.scheduler = scheduler; this.containerManager = containerManager; this.executorRegistry = executorRegistry; @@ -481,6 +490,22 @@ private void handleControlMessage(final ControlMessage.Message message) { .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build()) .build()); break; + case ParentTaskDataCollected: + if (scheduler instanceof BatchScheduler) { + final ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected(); + final String taskId = workStealingMsg.getTaskId(); + final Map partitionSizeMap = SerializationUtils + .deserialize(workStealingMsg.getPartitionSizeMap().toByteArray()); + ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); + } + break; + case CurrentlyProcessedBytesCollected: + if (scheduler instanceof BatchScheduler) { + ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes( + message.getCurrentlyProcessedBytesCollected().getTaskId(), + message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes() + ); + } + break; case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java index 16c9a70db9..ebec804132 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java @@ -170,6 +170,11 @@ public void onTaskExecutionFailed(final String taskId) { failedTasks.add(failedTask); } + @Override + public boolean isExecutorSlotAvailable() { + return
getExecutorCapacity() - getNumOfRunningTasks() > 0; + } + /** * @return how many Tasks can this executor simultaneously run */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java index 26649a81db..dcfb53eb1c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java @@ -108,4 +108,9 @@ public interface ExecutorRepresenter { * @param taskId id of the Task */ void onTaskExecutionFailed(String taskId); + + /** + * @return true if this executor has an available slot. + */ + boolean isExecutorSlotAvailable(); } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 086c9d08bd..d30ea42394 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -18,18 +18,26 @@ */ package org.apache.nemo.runtime.master.scheduler; +import com.google.protobuf.ByteString; import org.apache.commons.lang.mutable.MutableBoolean; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.Pair; +import org.apache.nemo.common.dag.Vertex; import org.apache.nemo.common.exception.UnknownExecutionStateException; import org.apache.nemo.common.exception.UnrecoverableFailureException; +import org.apache.nemo.common.ir.Readable; import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; +import org.apache.nemo.runtime.common.comm.ControlMessage; +import org.apache.nemo.runtime.common.message.MessageEnvironment; +import org.apache.nemo.runtime.common.metric.TaskMetric; import org.apache.nemo.runtime.common.plan.*; import org.apache.nemo.runtime.common.state.StageState; import org.apache.nemo.runtime.common.state.TaskState; import org.apache.nemo.runtime.master.BlockManagerMaster; import org.apache.nemo.runtime.master.PlanAppender; import org.apache.nemo.runtime.master.PlanStateManager; +import org.apache.nemo.runtime.master.metric.MetricStore; import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; import org.apache.reef.annotations.audience.DriverSide; import org.slf4j.Logger; @@ -38,7 +46,9 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import javax.inject.Inject; +import java.io.Serializable; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; /** @@ -66,6 +76,7 @@ public final class BatchScheduler implements Scheduler { private final PendingTaskCollectionPointer pendingTaskCollectionPointer; // A 'pointer' to the list of pending tasks. private final ExecutorRegistry executorRegistry; // A registry for executors available for the job. private final PlanStateManager planStateManager; // A component that manages the state of the plan. + private final MetricStore metricStore = MetricStore.getStore(); /** * Other necessary components of this {@link org.apache.nemo.runtime.master.RuntimeMaster}. 
@@ -77,6 +88,14 @@ public final class BatchScheduler implements Scheduler { */ private List> sortedScheduleGroups; // Stages, sorted in the order to be scheduled. + /** + * Data structures for work stealing. + */ + private final Set workStealingCandidates = new HashSet<>(); + private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + private final Map taskIdToProcessedBytes = new HashMap<>(); + private final Map stageIdToWorkStealingExecuted = new HashMap<>(); + @Inject private BatchScheduler(final PlanRewriter planRewriter, final TaskDispatcher taskDispatcher, @@ -111,6 +130,11 @@ public void updatePlan(final PhysicalPlan newPhysicalPlan) { private void updatePlan(final PhysicalPlan newPhysicalPlan, final int maxScheduleAttempt) { planStateManager.updatePlan(newPhysicalPlan, maxScheduleAttempt); + + for (Stage stage : planStateManager.getPhysicalPlan().getStageDAG().getVertices()) { + stageIdToWorkStealingExecuted.putIfAbsent(stage.getId(), false); + } + this.sortedScheduleGroups = newPhysicalPlan.getStageDAG().getVertices().stream() .collect(Collectors.groupingBy(Stage::getScheduleGroup)) .entrySet().stream() @@ -258,6 +282,51 @@ public void onSpeculativeExecutionCheck() { } } + @Override + public void onWorkStealingCheck() { + final List scheduleGroup = BatchSchedulerUtils + .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); + final List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + final Map> wsResult = new HashMap<>(); + + /* If work stealing is not currently possible, return. */ + if (!checkForWorkStealingBaseConditions(scheduleGroupInId)) { + return; + } + + /* detect skewed tasks */ + taskIdToProcessedBytes.clear(); + final List skewedTasks = detectSkew(scheduleGroupInId); + + /* if there are no skewed tasks, return */ + if (skewedTasks.isEmpty()) { + return; + } + + /* generate work stealing tasks */ + final Map> taskToSplitIteratorInfo = splitIterator(skewedTasks); + final List wsTasks = generateWorkStealingTasks(scheduleGroup, skewedTasks, taskToSplitIteratorInfo); + + /* accumulate result */ + for (String taskId : workStealingCandidates) { + if (skewedTasks.contains(taskId)) { // skewed task: keep only [0, changePoint) + Pair iteratorInfo = taskToSplitIteratorInfo.get(taskId); + wsResult.put(taskId, Pair.of(0, iteratorInfo.left())); + } else { // non-skewed task: keep the full iterator range + wsResult.put(taskId, Pair.of(0, Integer.MAX_VALUE)); + } + } + + /* notify the executors of the updated iterator ranges */ + sendWorkStealingResultToExecutor(wsResult); + + /* schedule new tasks */ + pendingTaskCollectionPointer.setToOverwrite(wsTasks); + taskDispatcher.onNewPendingTaskCollectionAvailable(); + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName());
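To make the ranges computed by onWorkStealingCheck() concrete (splitIterator() is defined further below): for a skewed task currently at iterator index 10 of 50, the split point is (50 + 10) / 2 + 1 = 31, so the original task keeps [0, 31) and its work stealing task is created with the range [31, 50). A worked sketch with hypothetical numbers:

    final int currIterIdx = 10;
    final int totalIterNum = 50;
    final int changePoint = (totalIterNum + currIterIdx) / 2 + 1; // 31
    // wsResult for the skewed task: Pair.of(0, 31)
    // its stealer Task: iteratorStartIndex = 31 (inclusive), iteratorEndIndex = 50 (exclusive)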
@@ -304,6 +373,9 @@ public void terminate() { * - We make {@link TaskDispatcher} dispatch only the tasks that are READY. */ private void doSchedule() { + taskIdToProcessedBytes.clear(); + workStealingCandidates.clear(); + final Optional> earliest = BatchSchedulerUtils.selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager); @@ -383,4 +455,349 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } + + ///////////////////////////////////////////////////////////////// Methods for work stealing + + /** + * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. + * KEY is assumed to be an Integer because hash partitioning is assumed. + * + * @param taskId id of the task to accumulate. + * @param partitionSizeMap map of (K) - (partition size) of the task. + */ + public void aggregateStageIdToPartitionSizeMap(final String taskId, + final Map partitionSizeMap) { + final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap + .getOrDefault(RuntimeIdManager.getStageIdFromTaskId(taskId), new HashMap<>()); + partitionSizeMap.forEach((hashedKey, partitionSize) -> + partitionSizeMapForThisStage.merge(hashedKey, partitionSize, Long::sum)); + stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); + } + + /** + * Store the tracked processed bytes per task by the current time. + * + * @param taskId id of the task to track. + * @param processedBytes size of the bytes processed so far. + */ + public void aggregateTaskIdToProcessedBytes(final String taskId, + final long processedBytes) { + taskIdToProcessedBytes.put(taskId, processedBytes); + } + + /** + * Check if work stealing can be conducted. + * + * @param scheduleGroup schedule group. + * @return true if work stealing is possible. + */ + private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { + if (scheduleGroup.isEmpty()) { + return false; + } + + /* If work stealing has already been executed on any stage of this schedule group, return false */ + if (scheduleGroup.stream().anyMatch(stageId -> stageIdToWorkStealingExecuted.getOrDefault(stageId, false))) { + return false; + } + + /* Return true only if some executor has an idle slot and the number of remaining tasks + * is smaller than the total number of executor slots. */ + final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); + final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); + int remainingTasks = 0; + for (String stage : scheduleGroup) { + remainingTasks += planStateManager.getNumberOfTasksRemainingInStage(stage); // tasks not yet COMPLETE or ON_HOLD + } + return executorStatus && (totalNumberOfSlots > remainingTasks); + } + + /** + * Get the ids of tasks in execution. + * + * @param scheduleGroup schedule group. + * @return ids of running tasks. + */ + private Set getRunningTaskId(final List scheduleGroup) { + final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); + for (String stageId : scheduleGroup) { + onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); + } + return onGoingTasksOfSchedulingGroup; + } + + /** + * Get the parent stages of the given schedule group. + * + * @param scheduleGroup schedule group. + * @return Map of each stage ID to the set of its parent stage IDs.
+ */ + private Map> getParentStages(final List scheduleGroup) { + Map> parentStages = new HashMap<>(); + for (String stageId : scheduleGroup) { + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId) + .stream() + .map(Vertex::getId) + .collect(Collectors.toSet())); + } + return parentStages; + } + + /** + * Get the input size of running tasks. + * + * @param parentStageIds id of parent stages. + * @param runningTaskIds id of running tasks. + * @return Map of task id to its input size. + */ + private Map getInputSizeOfRunningTasks(final Set parentStageIds, + final Set runningTaskIds) { + Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); + for (String parent : parentStageIds) { + Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); + for (String taskId : runningTaskIds) { + currentlyRunningTaskIdsToTotalSize.merge(taskId, + taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId)), Long::sum); + } + } + return currentlyRunningTaskIdsToTotalSize; + } + + /** + * Get the current execution time of running tasks in milliseconds. + * Note that this is the execution time of incomplete tasks. + * + * @param scheduleGroup schedule group. + * @return Map of task id to its execution time. + */ + private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { + final Map taskToExecutionTime = new HashMap<>(); + for (String stageId : scheduleGroup) { + taskToExecutionTime.putAll(planStateManager.getExecutingTaskToRunningTimeMs(stageId)); + } + return taskToExecutionTime; + } + + private List getScheduleGroupByStage(final String stageId) { + return sortedScheduleGroups.get( + planStateManager.getPhysicalPlan().getStageDAG().getVertexById(stageId).getScheduleGroup()) + .stream() + .map(Vertex::getId) + .collect(Collectors.toList()); + } + + /** + * Detect skewed tasks. + * + * @param scheduleGroup current schedule group. + * @return List of skewed tasks.
+ */ + private List detectSkew(final List scheduleGroup) { + final Map> taskIdToIteratorInformation = new HashMap<>(); + final Map taskIdToInitializationOverhead = new HashMap<>(); + final Map inputSizeOfCandidateTasks = new HashMap<>(); + final Map> parentStageId = getParentStages(scheduleGroup); + + /* if this schedule group contains a source stage, return an empty list */ + if (scheduleGroup.stream().anyMatch(stage -> + planStateManager.getPhysicalPlan().getStageDAG().getParents(stage).isEmpty())) { + return new ArrayList<>(); + } + + workStealingCandidates.addAll(getRunningTaskId(scheduleGroup)); + + /* Gather statistics of work stealing candidates */ + /* get the input size of running tasks */ + for (String stage : scheduleGroup) { + inputSizeOfCandidateTasks.putAll( + getInputSizeOfRunningTasks(parentStageId.get(stage), workStealingCandidates)); + } + + /* get the elapsed time */ + Map taskIdToElapsedTime = getCurrentExecutionTimeMsOfRunningTasks(scheduleGroup); + + /* gather task metrics */ + for (String taskId : workStealingCandidates) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + + taskIdToProcessedBytes.put(taskId, taskMetric.getSerializedReadBytes()); + taskIdToIteratorInformation.put(taskId, Pair.of( + taskMetric.getCurrentIteratorIndex(), taskMetric.getTotalIteratorNumber())); + taskIdToInitializationOverhead.put(taskId, taskMetric.getTaskPreparationTime()); + } + + /* If the gathered statistics are not sufficient for skew detection, return an empty list. */ + if (taskIdToProcessedBytes.size() <= workStealingCandidates.size() / 2) { + return new ArrayList<>(); + } + + /* estimate the remaining time */ + List> estimatedTimeToFinishPerTask = new ArrayList<>(taskIdToElapsedTime.size()); + + for (String taskId : taskIdToProcessedBytes.keySet()) { + // if processed bytes are not available, do not detect skew. + if (taskIdToProcessedBytes.get(taskId) <= 0) { + return new ArrayList<>(); + } + + // if this task is almost finished, ignore it. + Pair iteratorInformation = taskIdToIteratorInformation.get(taskId); + if (iteratorInformation.right() - iteratorInformation.left() <= 2) { + continue; + } + + long timeToFinishExecute = taskIdToElapsedTime.get(taskId) * inputSizeOfCandidateTasks.get(taskId) + / taskIdToProcessedBytes.get(taskId); + + // if the estimated remaining time is shorter than twice the initialization overhead, skip this task. + if (timeToFinishExecute < taskIdToInitializationOverhead.get(taskId) * 2) { + continue; + } + + estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); + } + + /* detect skew: sort by estimated time to finish, longest first */ + estimatedTimeToFinishPerTask.sort((o1, o2) -> o2.right().compareTo(o1.right())); + + /* return only the longer half */ + return estimatedTimeToFinishPerTask + .subList(0, estimatedTimeToFinishPerTask.size() / 2) + .stream().map(Pair::left).collect(Collectors.toList()); + } + + /** + * Calculate the iterator range of work stealing tasks. + * Given a skewed task, it calculates the iterator range which the work stealing task will take over from it. + * + * @param skewedTasks List of skewed (original) tasks. + * @return Map of skewed task ID to iterator information. + * pair.left() is the starting index (inclusive) and pair.right() the ending index (exclusive).
+ */ + private Map> splitIterator(final List skewedTasks) { + final Map> taskToIteratorInfo = new HashMap<>(); + + for (String taskId : skewedTasks) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + int currIterIdx = taskMetric.getCurrentIteratorIndex(); + int totalIterIndex = taskMetric.getTotalIteratorNumber(); + int changePoint = (totalIterIndex + currIterIdx) / 2 + 1; // integer division already floors + + taskToIteratorInfo.put(taskId, Pair.of(changePoint, totalIterIndex)); + } + + return taskToIteratorInfo; + } + + /** + * Generate work stealing tasks. + * + * @param scheduleGroup schedule group. + * @param skewedTasks List of skewed (original) tasks. + * @param taskToIteratorInfo Map of skewed (original) task ID to its iterator range information. + * @return List of work stealer tasks. + */ + private List generateWorkStealingTasks(final List scheduleGroup, + final List skewedTasks, + final Map> taskToIteratorInfo) { + final List tasksToSchedule = new ArrayList<>(skewedTasks.size()); + + /* tasks are generated per stage: loop by stage, not by schedule group */ + for (Stage stageToSchedule : scheduleGroup) { + String stageId = stageToSchedule.getId(); + + /* make new task ids and store that information in the stage and the plan state manager. + * for now, the id logic for work stealing tasks is as follows: + * - same stage id + * - same index number + * - the attempt number is replaced with "*", similar to the block wildcard id. + */ + + /* generate work stealing task ids */ + final Set newTaskIds = skewedTasks.stream() + .filter(taskId -> taskId.contains(stageId)) + .map(RuntimeIdManager::generateWorkStealingTaskId) + .collect(Collectors.toSet()); + + /* if there are no work stealing tasks in this stage, pass */ + if (newTaskIds.isEmpty()) { + continue; + } + + /* update the work stealing tasks in Stage and PlanStateManager */ + planStateManager.getPhysicalPlan().getStageDAG() + .getVertexById(stageId).setWorkStealingTaskIds(newTaskIds); + planStateManager.addWorkStealingTasks(newTaskIds); + + /* create work stealing tasks */ + final List stageIncomingEdges = + planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stageToSchedule.getId()); + final List stageOutgoingEdges = + planStateManager.getPhysicalPlan().getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId()); + final List> vertexIdToReadable = stageToSchedule.getVertexIdToReadables(); + + skewedTasks.forEach(taskId -> { + final Set blockIds = BatchSchedulerUtils.getOutputBlockIds(planStateManager, taskId); + blockManagerMaster.onProducerTaskScheduled(taskId, blockIds); + final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId); + + int startIterIdx = taskToIteratorInfo.get(taskId).left(); + int endIterIndex = taskToIteratorInfo.get(taskId).right(); + + tasksToSchedule.add(new Task( + planStateManager.getPhysicalPlan().getPlanId(), + RuntimeIdManager.generateWorkStealingTaskId(taskId), + stageToSchedule.getExecutionProperties(), + stageToSchedule.getSerializedIRDAG(), + stageIncomingEdges, + stageOutgoingEdges, + vertexIdToReadable.get(taskIdx), + new AtomicInteger(startIterIdx), + new AtomicInteger(endIterIndex))); + }); + + // Work stealing is executed at most once per stage: a limitation of the index-based + // task state tracking system that should be lifted in the future. + stageIdToWorkStealingExecuted.put(stageId, true); + } + + return tasksToSchedule; + } + + /** + * Send the accumulated iterator information (work stealing result) to every executor. + * + * @param result result to send.
+ */ + private void sendWorkStealingResultToExecutor(final Map> result) { + final byte[] serialized = SerializationUtils.serialize((Serializable) result); + ControlMessage.Message message = ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.SendWorkStealingResult) + .setSendWorkStealingResult(ControlMessage.WorkStealingResultMsg.newBuilder() + .setWorkStealingResult(ByteString.copyFrom(serialized)) + .build()) + .build(); + executorRegistry.viewExecutors(executors -> executors.forEach(executor -> executor.sendControlMessage(message))); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java index 11d40c73b8..5cead6e290 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java @@ -126,6 +126,14 @@ private Set getRunningExecutors() { .collect(Collectors.toSet()); } + public int getTotalNumberOfExecutorSlots() { + return getRunningExecutors().stream().mapToInt(ExecutorRepresenter::getExecutorCapacity).sum(); + } + + public boolean isExecutorSlotAvailable() { + return getRunningExecutors().stream().anyMatch(ExecutorRepresenter::isExecutorSlotAvailable); + } + @Override public String toString() { return executors.toString(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java index cc4661df64..afe30f6e73 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java @@ -86,6 +86,11 @@ void onTaskStateReportFromExecutor(String executorId, */ void onSpeculativeExecutionCheck(); + /** + * Called to check for work stealing condition. + */ + void onWorkStealingCheck(); + /** * To be called when a job should be terminated. * Any clean up code should be implemented in this method. diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java index 42870f609e..5885aa0ada 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java @@ -451,6 +451,12 @@ public void onSpeculativeExecutionCheck() { return; } + @Override + public void onWorkStealingCheck() { + // we don't simulate work stealing yet. 
+ return; + } + @Override public void terminate() { this.taskDispatcher.terminate(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java index 24e30bec87..ffa2c586da 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java @@ -149,6 +149,11 @@ public void onSpeculativeExecutionCheck() { throw new UnsupportedOperationException(); } + @Override + public void onWorkStealingCheck() { + throw new UnsupportedOperationException(); + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName());
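For reference, the estimate that drives skew detection in detectSkew() linearly extrapolates each running task's total runtime from its elapsed time, input size, and bytes processed so far; tasks in the longer half of these estimates become the skewed candidates. A worked example with hypothetical numbers:

    // A task that has run for 40s and processed 1 GB of its 5 GB input:
    final long elapsedMs = 40000L;
    final long inputBytes = 5L * 1024 * 1024 * 1024;
    final long processedBytes = 1L * 1024 * 1024 * 1024;
    final long estimatedTotalMs = elapsedMs * inputBytes / processedBytes; // 200000 ms
    // The task is skipped instead if it is nearly done (fewer than three iterators left)
    // or if the estimate is within twice its task preparation time.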