Skip to content

Commit eb44ff0

Browse files
committed
[SYSTEMDS-3910] OOC matrix-matrix multiplication
This patch introduces the MatrixMatrix multiplication logic. It performs a shuffle-based matrix multiplication on two large matrix streams. Implementation Detail: Asynchronous Producer: The processInstruction method launches a background thread to perform the entire two-phase multiplication, but returns control to the main thread immediately. This non-blocking setup allows the compiler to build the downstream execution plan while the OOC operation prepares to run upon data request. Two-Phase Streaming Logic: The background thread implements a shuffle-based algorithm to handle two large inputs: * Phase 1 (Grouping/Shuffle): It first consumes both input streams entirely. Blocks from each stream (A_ik and B_kj) are partitioned into groups based on the output block index (C_ij) they contribute to. A HashMap stores these groups, effectively "shuffling" the data for parallel processing. * Phase 2 (Aggregation/Reduce): After grouping, it processes each group independently. Within a group, it pairs the corresponding blocks using their common index k, performs the block-level multiplication, and aggregates the results to produce a single, final output block which is then enqueued to the output stream. Robust Block Identification: A TaggedMatrixValue wrapper is used during the grouping phase to explicitly tag each block with its source matrix (A or B). This ensures correct and unambiguous identification during the aggregation phase, a critical requirement that cannot be met by relying on block dimensions alone. Integration: The new instruction is fully integrated into the OOC framework: * The OOCInstructionParser is updated to recognize the aggregate binary in OOC context.
1 parent 8bed176 commit eb44ff0

File tree

5 files changed

+471
-157
lines changed

5 files changed

+471
-157
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
import org.apache.sysds.runtime.DMLRuntimeException;
2626
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
28+
import org.apache.sysds.runtime.instructions.ooc.MatrixMultiplyOOCInstruction;
2829
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2930
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3031
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
31-
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
3232
import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
3333

3434
public class OOCInstructionParser extends InstructionParser {
@@ -60,7 +60,7 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
6060
return BinaryOOCInstruction.parseInstruction(str);
6161
case AggregateBinary:
6262
case MAPMM:
63-
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
63+
return MatrixMultiplyOOCInstruction.parseInstruction(str);
6464
case Reorg:
6565
return TransposeOOCInstruction.parseInstruction(str);
6666

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import java.util.ArrayList;
23+
import java.util.HashMap;
24+
import java.util.HashSet;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Set;
28+
import java.util.concurrent.ExecutorService;
29+
30+
import org.apache.sysds.common.Opcodes;
31+
import org.apache.sysds.conf.ConfigurationManager;
32+
import org.apache.sysds.runtime.DMLRuntimeException;
33+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
34+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
35+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
36+
import org.apache.sysds.runtime.functionobjects.Multiply;
37+
import org.apache.sysds.runtime.functionobjects.Plus;
38+
import org.apache.sysds.runtime.instructions.InstructionUtils;
39+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
40+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
41+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
42+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
43+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
44+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
45+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
46+
import org.apache.sysds.runtime.matrix.operators.Operator;
47+
import org.apache.sysds.runtime.util.CommonThreadPool;
48+
49+
public class MatrixMultiplyOOCInstruction extends ComputationOOCInstruction {
50+
51+
52+
protected MatrixMultiplyOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
53+
super(type, op, in1, in2, out, opcode, istr);
54+
}
55+
56+
public static MatrixMultiplyOOCInstruction parseInstruction(String str) {
57+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
58+
InstructionUtils.checkNumFields(parts, 4);
59+
String opcode = parts[0];
60+
CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed)
61+
CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory)
62+
CPOperand out = new CPOperand(parts[3]);
63+
64+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
65+
AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
66+
67+
return new MatrixMultiplyOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str);
68+
}
69+
70+
@Override
71+
public void processInstruction( ExecutionContext ec ) {
72+
73+
if (ec.getMatrixObject(input2).getDataCharacteristics().getCols() == 1) {
74+
_processMatrixVector(ec);
75+
} else {
76+
_processMatrixMatrix(ec);
77+
}
78+
}
79+
80+
private void _processMatrixVector( ExecutionContext ec ) {
81+
// 1. Identify the inputs
82+
MatrixObject min = ec.getMatrixObject(input1); // big matrix
83+
MatrixBlock vin = ec.getMatrixObject(input2)
84+
.acquireReadAndRelease(); // in-memory vector
85+
86+
// 2. Pre-partition the in-memory vector into a hashmap
87+
HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>();
88+
int blksize = vin.getDataCharacteristics().getBlocksize();
89+
if (blksize < 0)
90+
blksize = ConfigurationManager.getBlocksize();
91+
for (int i = 0; i < vin.getNumRows(); i += blksize) {
92+
long key = (long) (i / blksize) + 1; // the key starts at 1
93+
int end_row = Math.min(i + blksize, vin.getNumRows());
94+
MatrixBlock vectorSlice = vin.slice(i, end_row - 1);
95+
partitionedVector.put(key, vectorSlice);
96+
}
97+
98+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
99+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
100+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
101+
ec.getMatrixObject(output).setStreamHandle(qOut);
102+
103+
ExecutorService pool = CommonThreadPool.get();
104+
try {
105+
// Core logic: background thread
106+
pool.submit(() -> {
107+
IndexedMatrixValue tmp = null;
108+
try {
109+
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
110+
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
111+
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
112+
long rowIndex = tmp.getIndexes().getRowIndex();
113+
long colIndex = tmp.getIndexes().getColumnIndex();
114+
MatrixBlock vectorSlice = partitionedVector.get(colIndex);
115+
116+
// Now, call the operation with the correct, specific operator.
117+
MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(
118+
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
119+
120+
// for single column block, no aggregation neeeded
121+
if (min.getNumColumns() <= min.getBlocksize()) {
122+
qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
123+
} else {
124+
MatrixBlock currAgg = partialResults.get(rowIndex);
125+
if (currAgg == null)
126+
partialResults.put(rowIndex, partialResult);
127+
else
128+
currAgg.binaryOperationsInPlace(plus, partialResult);
129+
}
130+
}
131+
132+
// emit aggregated blocks
133+
if (min.getNumColumns() > min.getBlocksize()) {
134+
for (Map.Entry<Long, MatrixBlock> entry : partialResults.entrySet()) {
135+
MatrixIndexes outIndexes = new MatrixIndexes(entry.getKey(), 1L);
136+
qOut.enqueueTask(new IndexedMatrixValue(outIndexes, entry.getValue()));
137+
}
138+
}
139+
} catch (Exception ex) {
140+
throw new DMLRuntimeException(ex);
141+
} finally {
142+
qOut.closeInput();
143+
}
144+
});
145+
} catch (Exception e) {
146+
throw new DMLRuntimeException(e);
147+
} finally {
148+
pool.shutdown();
149+
}
150+
}
151+
152+
private void _processMatrixMatrix( ExecutionContext ec ) {
153+
// 1. Identify the inputs
154+
MatrixObject min = ec.getMatrixObject(input1); // big matrix
155+
MatrixObject min2 = ec.getMatrixObject(input2);
156+
157+
LocalTaskQueue<IndexedMatrixValue> qIn1 = min.getStreamHandle();
158+
LocalTaskQueue<IndexedMatrixValue> qIn2 = min2.getStreamHandle();
159+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
160+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
161+
ec.getMatrixObject(output).setStreamHandle(qOut);
162+
163+
// Result matrix rows, cols = rows of A, cols of B
164+
long resultRowBlocks = min.getDataCharacteristics().getNumRowBlocks();
165+
long resultColBlocks = min2.getDataCharacteristics().getNumColBlocks();
166+
167+
ExecutorService pool = CommonThreadPool.get();
168+
try {
169+
// Core logic: background thread
170+
pool.submit(() -> {
171+
IndexedMatrixValue tmpA = null;
172+
IndexedMatrixValue tmpB = null;
173+
try {
174+
// Phase 1: grouping the output blocks by block Index (The Shuffle)
175+
Map<MatrixIndexes, List<TaggedMatrixValue>> groupedBlocks = new HashMap<>();
176+
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
177+
178+
// Process matrix A: each block A(i,k) contributes to C(i,j) for all j
179+
while((tmpA = qIn1.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
180+
long i = tmpA.getIndexes().getRowIndex() - 1;
181+
long k = tmpA.getIndexes().getColumnIndex() - 1;
182+
183+
for (int j=0; j<resultColBlocks; j++) {
184+
MatrixIndexes index = new MatrixIndexes(i, j); // 1,1= A11,A12,A13,B11,B21,B31
185+
186+
// Create a copy
187+
MatrixBlock sourceBlock = (MatrixBlock) tmpA.getValue();
188+
IndexedMatrixValue valueCopy = new IndexedMatrixValue(new MatrixIndexes(tmpA.getIndexes()), sourceBlock);
189+
190+
TaggedMatrixValue taggedValue = new TaggedMatrixValue(valueCopy, true, k);
191+
groupedBlocks.computeIfAbsent(index, idx -> new ArrayList<>()).add(taggedValue);
192+
}
193+
}
194+
195+
// Process matrix B: each block B(k,j) contributes to C(i,j) for all i
196+
while((tmpB = qIn2.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
197+
long k = tmpB.getIndexes().getRowIndex() - 1;
198+
long j = tmpB.getIndexes().getColumnIndex() - 1;
199+
200+
for (int i=0; i<resultRowBlocks; i++) {
201+
MatrixIndexes index = new MatrixIndexes(i, j);
202+
203+
MatrixBlock sourceBlock = (MatrixBlock) tmpB.getValue();
204+
IndexedMatrixValue valueCopy = new IndexedMatrixValue(new MatrixIndexes(tmpB.getIndexes()), sourceBlock);
205+
206+
TaggedMatrixValue taggedValue = new TaggedMatrixValue(valueCopy, false, k);
207+
groupedBlocks.computeIfAbsent(index,idx -> new ArrayList<>()).add(taggedValue);
208+
}
209+
}
210+
211+
212+
// Phase 2: Multiplication and Aggregation
213+
Map<MatrixIndexes, MatrixBlock> resultBlocks = new HashMap<>();
214+
215+
// Process each output block separately
216+
for (Map.Entry<MatrixIndexes, List<TaggedMatrixValue>> entry : groupedBlocks.entrySet()) {
217+
MatrixIndexes outIndex = entry.getKey();
218+
List<TaggedMatrixValue> outValues = entry.getValue();
219+
220+
// For this output block, collect left and right input blocks
221+
Map<Long, MatrixBlock> leftBlocks = new HashMap<>();
222+
Map<Long, MatrixBlock> rightBlocks = new HashMap<>();
223+
224+
// Organize blocks by k-index
225+
for (TaggedMatrixValue taggedValue : outValues) {
226+
IndexedMatrixValue value = taggedValue.getValue();
227+
long kIndex = taggedValue.getkIndex();
228+
229+
if (taggedValue.isFirstInput()) {
230+
leftBlocks.put(kIndex, (MatrixBlock)value.getValue());
231+
} else {
232+
rightBlocks.put(kIndex, (MatrixBlock)value.getValue());
233+
}
234+
}
235+
236+
// Create result block for this (i,j) position
237+
MatrixBlock resultBlock = null;
238+
239+
// Find k-indices that exist in both left and right
240+
Set<Long> commonKIndices = new HashSet<>(leftBlocks.keySet());
241+
commonKIndices.retainAll(rightBlocks.keySet());
242+
243+
// Multiply and aggregate matching blocks
244+
for (Long k : commonKIndices) {
245+
MatrixBlock leftBlock = leftBlocks.get(k);
246+
MatrixBlock rightBlock = rightBlocks.get(k);
247+
248+
// Multiply matching blocks
249+
MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock,
250+
rightBlock,
251+
new MatrixBlock(),
252+
InstructionUtils.getMatMultOperator(1));
253+
254+
if (resultBlock == null) {
255+
resultBlock = partialResult;
256+
} else {
257+
resultBlock = resultBlock.binaryOperationsInPlace(plus, partialResult);
258+
}
259+
}
260+
261+
// Store the final result for this output block
262+
if (resultBlock != null) {
263+
resultBlocks.put(outIndex, resultBlock);
264+
}
265+
}
266+
267+
// Enqueue all results after all multiplications are complete
268+
for (Map.Entry<MatrixIndexes, MatrixBlock> entry : resultBlocks.entrySet()) {
269+
MatrixIndexes outIdx0 = entry.getKey();
270+
MatrixBlock outBlock = entry.getValue();
271+
MatrixIndexes outIdx = new MatrixIndexes(outIdx0.getRowIndex() + 1,
272+
outIdx0.getColumnIndex() + 1);
273+
outBlock.checkSparseRows();
274+
qOut.enqueueTask(new IndexedMatrixValue(outIdx, outBlock));
275+
}
276+
277+
}
278+
catch(Exception ex) {
279+
throw new DMLRuntimeException(ex);
280+
}
281+
finally {
282+
qOut.closeInput();
283+
}
284+
});
285+
} catch (Exception e) {
286+
throw new DMLRuntimeException(e);
287+
}
288+
finally {
289+
pool.shutdown();
290+
}
291+
}
292+
293+
/**
294+
* Helper class to tag matrix block with their source and k-index
295+
*/
296+
private static class TaggedMatrixValue {
297+
IndexedMatrixValue _value;
298+
private long _kIndex;
299+
private boolean _isFirstInput;
300+
301+
public TaggedMatrixValue(IndexedMatrixValue value, boolean isFirstInput, long kIndex) {
302+
this._value = value;
303+
this._isFirstInput = isFirstInput;
304+
this._kIndex = kIndex;
305+
}
306+
307+
public IndexedMatrixValue getValue() {
308+
return _value;
309+
}
310+
311+
public boolean isFirstInput() {
312+
return _isFirstInput;
313+
}
314+
315+
public long getkIndex() {
316+
return _kIndex;
317+
}
318+
}
319+
}

0 commit comments

Comments
 (0)