Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ public static String generateTaskId(final String stageId, final int index, final
return stageId + SPLITTER + index + SPLITTER + attempt;
}

/**
 * Derives the ID of a task created by Work Stealing from the ID of an existing task.
 * The stage ID and the index of the original task are kept, and the literal "*" is
 * appended as the last component.
 *
 * @param taskId the ID of original task.
 * @return the generated ID.
 */
public static String generateWorkStealingTaskId(final String taskId) {
  final StringBuilder workStealingId = new StringBuilder();
  workStealingId.append(getStageIdFromTaskId(taskId)).append(SPLITTER);
  workStealingId.append(getIndexFromTaskId(taskId)).append(SPLITTER);
  workStealingId.append("*");
  return workStealingId.toString();
}

/**
* Generates the ID for executor.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric<TaskState.State> {
// NOTE(review): every metric below starts at -1, presumably meaning "not reported yet" - confirm.
private long shuffleReadTime = -1;
private long shuffleWriteBytes = -1;
private long shuffleWriteTime = -1;
// Iterator progress and task preparation time; written by processMetricMessage through the private setters.
private int currentIteratorIndex = -1;
private int totalIteratorNumber = -1;
private long taskPreparationTime = -1;

private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName());

Expand Down Expand Up @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) {
this.shuffleWriteTime = shuffleWriteTime;
}

/**
 * @return the current iterator index; -1 is the field's initial value.
 */
public final int getCurrentIteratorIndex() {
  return currentIteratorIndex;
}

/**
 * Stores the current iterator index; called when a "currentIteratorIndex" metric message is processed.
 *
 * @param currentIteratorIndex the value to store.
 */
private void setCurrentIteratorIndex(final int currentIteratorIndex) {
  this.currentIteratorIndex = currentIteratorIndex;
}

/**
 * @return the total iterator number; -1 is the field's initial value.
 */
public final int getTotalIteratorNumber() {
  return totalIteratorNumber;
}

/**
 * Stores the total iterator number; called when a "totalIteratorNumber" metric message is processed.
 *
 * @param totalIteratorNumber the value to store.
 */
private void setTotalIteratorNumber(final int totalIteratorNumber) {
  this.totalIteratorNumber = totalIteratorNumber;
}

/**
 * @return the task preparation time; -1 is the field's initial value.
 */
public final long getTaskPreparationTime() {
  return taskPreparationTime;
}

/**
 * Stores the task preparation time; called when a "taskPreparationTime" metric message is processed.
 *
 * @param taskPreparationTime the value to store.
 */
private void setTaskPreparationTime(final long taskPreparationTime) {
  this.taskPreparationTime = taskPreparationTime;
}

@Override
public final String getId() {
return id;
Expand Down Expand Up @@ -317,6 +344,14 @@ public final boolean processMetricMessage(final String metricField, final byte[]
case "shuffleWriteTime":
setShuffleWriteTime(SerializationUtils.deserialize(metricValue));
break;
case "currentIteratorIndex":
setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue));
break;
case "totalIteratorNumber":
setTotalIteratorNumber(SerializationUtils.deserialize(metricValue));
break;
case "taskPreparationTime":
  setTaskPreparationTime(SerializationUtils.deserialize(metricValue));
  // Terminate the case like its siblings do. Without this break, control falls through to
  // the default branch, which logs the field as unsupported and returns false even though
  // the value was stored.
  break;
default:
LOG.warn("metricField {} is not supported.", metricField);
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;

/**
* Stage.
*/
public final class Stage extends Vertex {
private final List<Integer> taskIndices;
// IDs of work stealer tasks; entries are added through setWorkStealingTaskIds.
private final Set<String> workStealingTaskIds = new HashSet<>();
private final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag;
private final byte[] serializedIRDag;
private final List<Map<String, Readable>> vertexIdToReadables;
Expand Down Expand Up @@ -93,6 +92,18 @@ public List<Integer> getTaskIndices() {
return taskIndices;
}

/**
 * Registers IDs of work stealer tasks for this stage.
 * The given IDs are added to the IDs already held; previously registered IDs are kept.
 *
 * @param workStealingTaskIds IDs of work stealer tasks.
 */
public void setWorkStealingTaskIds(final Set<String> workStealingTaskIds) {
  for (final String workStealerId : workStealingTaskIds) {
    this.workStealingTaskIds.add(workStealerId);
  }
}

/**
 * @return the set of work stealer task IDs held by this stage (the internal set itself, not a copy).
 */
public Set<String> getWorkStealingTaskIds() {
  return workStealingTaskIds;
}

/**
* @return the parallelism.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A Task (attempt) is a self-contained executable that can be executed on a machine.
Expand All @@ -40,8 +41,13 @@ public final class Task implements Serializable {
private final byte[] serializedIRDag;
private final Map<String, Readable> irVertexIdToReadable;

/* For work stealing: starting and ending index of the iterator.
 * The constructor without index arguments passes 0 and Integer.MAX_VALUE. */
private final AtomicInteger iteratorStartIndex;
private final AtomicInteger iteratorEndIndex;

/**
* Constructor.
* Default Constructor.
* It initializes iteratorStartIndex as 0 and iteratorEndIndex as Integer.MAX_VALUE.
*
* @param planId the id of the physical plan.
* @param taskId the ID of this task attempt.
Expand All @@ -58,13 +64,42 @@ public Task(final String planId,
final List<StageEdge> taskIncomingEdges,
final List<StageEdge> taskOutgoingEdges,
final Map<String, Readable> irVertexIdToReadable) {
this(planId, taskId, executionProperties, serializedIRDag, taskIncomingEdges, taskOutgoingEdges,
irVertexIdToReadable, new AtomicInteger(0), new AtomicInteger(Integer.MAX_VALUE));
}

/**
 * Creates a task together with the iterator range it covers.
 * Work stealer tasks are created through this constructor.
 * All arguments are stored as given; no copies are made.
 *
 * @param planId               the id of the physical plan.
 * @param taskId               the ID of this task attempt.
 * @param executionProperties  {@link VertexExecutionProperty} map for the corresponding stage
 * @param serializedIRDag      the serialized DAG of the task.
 * @param taskIncomingEdges    the incoming edges of the task.
 * @param taskOutgoingEdges    the outgoing edges of the task.
 * @param irVertexIdToReadable the map between IRVertex id to readable.
 * @param iteratorStartIndex   starting index of iterator.
 * @param iteratorEndIndex     ending index of iterator.
 */
public Task(final String planId,
            final String taskId,
            final ExecutionPropertyMap<VertexExecutionProperty> executionProperties,
            final byte[] serializedIRDag,
            final List<StageEdge> taskIncomingEdges,
            final List<StageEdge> taskOutgoingEdges,
            final Map<String, Readable> irVertexIdToReadable,
            final AtomicInteger iteratorStartIndex,
            final AtomicInteger iteratorEndIndex) {
  // Identity of the task.
  this.planId = planId;
  this.taskId = taskId;

  // What the task executes and how.
  this.executionProperties = executionProperties;
  this.serializedIRDag = serializedIRDag;
  this.irVertexIdToReadable = irVertexIdToReadable;

  // Data flowing into and out of the task.
  this.taskIncomingEdges = taskIncomingEdges;
  this.taskOutgoingEdges = taskOutgoingEdges;

  // Iterator range used for work stealing.
  this.iteratorStartIndex = iteratorStartIndex;
  this.iteratorEndIndex = iteratorEndIndex;
}

/**
Expand Down
20 changes: 20 additions & 0 deletions runtime/common/src/main/proto/ControlMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ enum MessageType {
PipeInit = 13;
RequestPipeLoc = 14;
PipeLocInfo = 15;
ParentTaskDataCollected = 16;
CurrentlyProcessedBytesCollected = 17;
SendWorkStealingResult = 18;
}

message Message {
Expand All @@ -107,6 +110,9 @@ message Message {
optional PipeInitMessage pipeInitMsg = 16;
optional RequestPipeLocationMessage requestPipeLocMsg = 17;
optional PipeLocationInfoMessage pipeLocInfoMsg = 18;
optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19;
optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20;
optional WorkStealingResultMsg sendWorkStealingResult = 22;
}

// Messages from Master to Executors
Expand Down Expand Up @@ -256,3 +262,17 @@ message PipeLocationInfoMessage {
required int64 requestId = 1; // To find the matching request msg
required string executorId = 2;
}

// Payload of the ParentTaskDataCollected message type: a task ID plus an opaque byte blob.
// NOTE(review): partitionSizeMap is presumably a serialized map of partition key to size - confirm the encoding with the sender.
message ParentTaskDataCollectMsg {
required string taskId = 1;
required bytes partitionSizeMap = 2;
}

// Payload of the CurrentlyProcessedBytesCollected message type: a task ID and a byte count.
// NOTE(review): processedDataBytes presumably is the number of bytes the task has processed so far - confirm with the sender.
message CurrentlyProcessedBytesCollectMsg {
required string taskId = 1;
required int64 processedDataBytes = 2;
}

// Payload of the SendWorkStealingResult message type.
// NOTE(review): workStealingResult is an opaque byte blob; its serialization format is not visible here - confirm with the sender.
message WorkStealingResultMsg {
required bytes workStealingResult = 1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter {

private long writtenBytes;

// Result of blockToWrite.commit(), stored by close().
// NOTE(review): declared without an initializer, so presumably null until close() runs - confirm against the constructor.
private Optional<Map<Integer, Long>> partitionSizeMap;

/**
* Constructor.
*
Expand Down Expand Up @@ -109,7 +111,7 @@ public void close() {
final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge
.getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new);

final Optional<Map<Integer, Long>> partitionSizeMap = blockToWrite.commit();
partitionSizeMap = blockToWrite.commit();
// Return the total size of the committed block.
if (partitionSizeMap.isPresent()) {
long blockSizeTotal = 0;
Expand All @@ -123,6 +125,16 @@ public void close() {
blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence);
}

/**
 * Returns the partition size map obtained when the block was committed in {@link #close()}.
 *
 * @return the stored partition size map, or an empty Optional when none is available,
 *         including when no value has been stored in the field yet.
 */
@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
  // The field is declared without an initializer, so it can still be null when this is
  // called before close() has assigned it; report that as "absent" instead of throwing
  // a NullPointerException. A non-null value is already the Optional to hand back.
  return partitionSizeMap == null ? Optional.empty() : partitionSizeMap;
}

@Override
public Optional<Long> getWrittenBytes() {
if (writtenBytes == -1) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.punctuation.Watermark;

import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -45,5 +46,10 @@ public interface OutputWriter {
*/
Optional<Long> getWrittenBytes();

/**
* @return the map of hashed key to partition size, or an empty Optional when the writer has no such map to report.
*/
Optional<Map<Integer, Long>> getPartitionSizeMap();

void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -113,6 +114,11 @@ public Optional<Long> getWrittenBytes() {
return Optional.empty();
}

/**
 * This writer reports no partition sizes.
 *
 * @return always an empty Optional.
 */
@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
  final Optional<Map<Integer, Long>> noPartitionSizes = Optional.empty();
  return noPartitionSizes;
}

@Override
public void close() {
if (!initialized) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.runtime.executor.MetricMessageSender;

import java.io.IOException;

Expand Down Expand Up @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable {
*/
abstract Object fetchDataElement() throws IOException;

/**
* Fetches a data element in the same way as fetchDataElement(), and additionally sends the
* intermediate serializedReadBytes to MetricStore on every iterator advance.
* This method exists for the work stealing implementation in Nemo.
*
* @param taskId the ID of the task on whose behalf the element is fetched
* @param metricMessageSender the sender used to deliver the metric
*
* @return data element
* @throws IOException upon I/O error
* @throws java.util.NoSuchElementException if no more element is available
*/
abstract Object fetchDataElementWithTrace(String taskId,
MetricMessageSender metricMessageSender) throws IOException;

/**
 * @return the output collector held by this data fetcher.
 */
OutputCollector getOutputCollector() {
  return this.outputCollector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.runtime.executor.MetricMessageSender;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.datatransfer.*;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException {
}
}

/**
 * Fetches the next data element by delegating to fetchDataElement().
 * Neither argument is used by this implementation.
 *
 * @param taskId task id (unused here)
 * @param metricMessageSender metricMessageSender (unused here)
 * @return the element returned by fetchDataElement()
 * @throws IOException upon I/O error
 */
@Override
Object fetchDataElementWithTrace(final String taskId,
                                 final MetricMessageSender metricMessageSender) throws IOException {
  final Object fetchedElement = fetchDataElement();
  return fetchedElement;
}

private void fetchDataLazily() {
final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read();
numOfIterators = futures.size();
Expand Down
Loading