
Commit ba2b12f

SparkOperations implementation (#14)

1 parent ee82c02 commit ba2b12f

File tree: .gitignore, README.md, pipeline_dp/__init__.py, pipeline_dp/pipeline_operations.py, tests/pipeline_operations_test.py

5 files changed: +119 -5 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+*.pyc
+/.idea

README.md

Lines changed: 2 additions & 2 deletions

@@ -12,11 +12,11 @@ Google Python Style Guide https://google.github.io/styleguide/pyguide.html
 
 ### Installing dependencies
 
-This project depends on numpy apache-beam absl-py dataclasses
+This project depends on numpy apache-beam pyspark absl-py dataclasses
 
 For installing with pip please run:
 
-1. `pip install numpy apache-beam absl-py`
+1. `pip install numpy apache-beam pyspark absl-py`
 
 2. (for Python 3.6) `pip install dataclasses`

pipeline_dp/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -3,4 +3,5 @@
 from pipeline_dp.dp_engine import DataExtractors
 from pipeline_dp.dp_engine import Metrics
 from pipeline_dp.dp_engine import DPEngine
-from pipeline_dp.pipeline_operations import BeamOperations
+from pipeline_dp.pipeline_operations import BeamOperations
+from pipeline_dp.pipeline_operations import SparkRDDOperations
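The new export makes the Spark adapter importable from the package root alongside the Beam one. A minimal sketch of what that enables (the variable name `ops` is illustrative):

```python
# Both adapters are now exposed at the package root.
from pipeline_dp import BeamOperations, SparkRDDOperations

ops = SparkRDDOperations()  # or BeamOperations(), depending on the runner
```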

pipeline_dp/pipeline_operations.py

Lines changed: 69 additions & 1 deletion

@@ -1,5 +1,7 @@
 """Adapters for working with pipeline frameworks."""
 
+import random
+
 import abc
 import apache_beam as beam
 import apache_beam.transforms.combiners as combiners

@@ -58,6 +60,16 @@ def map_values(self, col, fn, stage_name: str):
         return col | stage_name >> beam.MapTuple(lambda k, v: (k, fn(v)))
 
     def group_by_key(self, col, stage_name: str):
+        """Group the values for each key in the PCollection into a single sequence.
+
+        Args:
+          col: input collection
+          stage_name: name of the stage
+
+        Returns:
+          A PCollection of tuples in which the type of the second item is list.
+
+        """
         return col | stage_name >> beam.GroupByKey()
 
     def filter(self, col, fn, stage_name: str):

@@ -76,13 +88,69 @@ def count_per_element(self, col, stage_name: str):
         return col | stage_name >> combiners.Count.PerElement()
 
 
+class SparkRDDOperations(PipelineOperations):
+    """Apache Spark RDD adapter."""
+
+    def map(self, rdd, fn, stage_name: str = None):
+        return rdd.map(fn)
+
+    def map_tuple(self, rdd, fn, stage_name: str = None):
+        # Unpack each (k, v) pair so fn receives two arguments,
+        # matching the Beam and local adapters.
+        return rdd.map(lambda x: fn(*x))
+
+    def map_values(self, rdd, fn, stage_name: str = None):
+        return rdd.mapValues(fn)
+
+    def group_by_key(self, rdd, stage_name: str = None):
+        """Group the values for each key in the RDD into a single sequence.
+
+        Args:
+          rdd: input RDD
+          stage_name: not used
+
+        Returns:
+          An RDD of tuples in which the type of the second item
+          is pyspark.resultiterable.ResultIterable.
+
+        """
+        return rdd.groupByKey()
+
+    def filter(self, rdd, fn, stage_name: str = None):
+        return rdd.filter(fn)
+
+    def keys(self, rdd, stage_name: str = None):
+        return rdd.keys()
+
+    def values(self, rdd, stage_name: str = None):
+        return rdd.values()
+
+    def sample_fixed_per_key(self, rdd, n: int, stage_name: str = None):
+        """Get fixed-size random samples for each unique key in an RDD of key-value pairs.
+
+        Sampling is not guaranteed to be uniform across partitions.
+
+        Args:
+          rdd: input RDD
+          n: number of values to sample for each key
+          stage_name: not used
+
+        Returns:
+          An RDD of tuples.
+
+        """
+        return rdd.mapValues(lambda x: [x])\
+            .reduceByKey(lambda x, y: random.sample(x + y, min(len(x) + len(y), n)))
+
+    def count_per_element(self, rdd, stage_name: str = None):
+        return rdd.map(lambda x: (x, 1))\
+            .reduceByKey(lambda x, y: x + y)
+
+
 class LocalPipelineOperations(PipelineOperations):
     """Local Pipeline adapter."""
 
     def map(self, col, fn, stage_name: str = None):
         return map(fn, col)
 
-    def map_tuple(self, col, fn, stage_name: str = None):
+    def map_tuple(self, col, fn, stage_name: str):
         return (fn(k, v) for k, v in col)
 
     def map_values(self, col, fn, stage_name: str):
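To make the new adapter concrete, here is a minimal usage sketch of its two composite methods against a local SparkContext. The `local[1]` master string and the sample data are assumptions for illustration; which values survive `sample_fixed_per_key` is random by design and, per the docstring, not guaranteed to be uniform, since `reduceByKey` re-samples every time two partial lists merge.

```python
import pyspark

from pipeline_dp.pipeline_operations import SparkRDDOperations

# Assumption for illustration: a single-threaded local Spark master.
sc = pyspark.SparkContext(conf=pyspark.SparkConf().setMaster("local[1]"))
ops = SparkRDDOperations()

pairs = sc.parallelize([(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)])

# Keep at most n=2 values per key: each value starts as a singleton list,
# and reduceByKey re-samples down to n as the lists are merged.
print(ops.sample_fixed_per_key(pairs, 2).collect())
# e.g. [(1, [11, 16]), (2, [22, 25]), (3, [33])]

# Classic word-count shape: pair each element with 1, then sum per key.
print(ops.count_per_element(sc.parallelize(["a", "b", "a"])).collect())
# [('a', 2), ('b', 1)]  (order may vary)

sc.stop()
```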

tests/pipeline_operations_test.py

Lines changed: 44 additions & 1 deletion

@@ -1,12 +1,45 @@
 import unittest
+import pyspark
 
+from pipeline_dp.pipeline_operations import SparkRDDOperations
 from pipeline_dp.pipeline_operations import LocalPipelineOperations
 
 
 class PipelineOperationsTest(unittest.TestCase):
     pass
 
 
+class SparkRDDOperationsTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        conf = pyspark.SparkConf()
+        cls.sc = pyspark.SparkContext(conf=conf)
+
+    def test_sample_fixed_per_key(self):
+        spark_operations = SparkRDDOperations()
+        data = [(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)]
+        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
+        rdd = spark_operations.sample_fixed_per_key(dist_data, 2)
+        result = dict(rdd.collect())
+        self.assertEqual(len(result[1]), 2)
+        self.assertTrue(set(result[1]).issubset({11, 14, 16}))
+        self.assertSetEqual(set(result[2]), {22, 25})
+        self.assertSetEqual(set(result[3]), {33})
+
+    def test_count_per_element(self):
+        spark_operations = SparkRDDOperations()
+        data = ['a', 'b', 'a']
+        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
+        rdd = spark_operations.count_per_element(dist_data)
+        result = dict(rdd.collect())
+        self.assertDictEqual(result, {'a': 2, 'b': 1})
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sc.stop()
+
+
 class LocalPipelineOperationsTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):

@@ -22,10 +55,20 @@ def test_local_map(self):
         self.assertEqual(list(self.ops.map(range(5), lambda x: x ** 2)),
                          [0, 1, 4, 9, 16])
 
+    def test_local_map_returns_iterator(self):
+        some_map = self.ops.map([1, 2, 3], lambda x: x)
+        # some_map is its own consumable iterator
+        self.assertIs(some_map, iter(some_map))
+
+        self.assertEqual(list(self.ops.map([1, 2, 3], str)),
+                         ["1", "2", "3"])
+        self.assertEqual(list(self.ops.map(range(5), lambda x: x ** 2)),
+                         [0, 1, 4, 9, 16])
+
     def test_local_map_tuple(self):
         tuple_list = [(1, 2), (2, 3), (3, 4)]
 
-        self.assertEqual(list(self.ops.map_tuple(tuple_list, lambda k, v: k+v)),
+        self.assertEqual(list(self.ops.map_tuple(tuple_list, lambda k, v: k + v)),
                          [3, 5, 7])
 
         self.assertEqual(list(self.ops.map_tuple(tuple_list, lambda k, v: (
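The `assertIs(some_map, iter(some_map))` check above leans on a general Python fact: an iterator's `__iter__` returns the iterator itself, so the lazy result of `LocalPipelineOperations.map` can be consumed exactly once. A standalone sketch:

```python
# map() returns a lazy iterator; iter() on an iterator returns the same
# object, and a second pass over it yields nothing.
m = map(str, [1, 2, 3])
assert m is iter(m)
assert list(m) == ["1", "2", "3"]
assert list(m) == []  # already exhausted
```

The generator returned by `map_tuple` behaves the same way: it supports a single pass.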
