
Commit c84d916

Davies Liu authored and rxin committed
[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas
```
select(['cola', 'colb'])
groupby(['colA', 'colB'])
groupby([df.colA, df.colB])
df.sort('A', ascending=True)
df.sort(['A', 'B'], ascending=True)
df.sort(['A', 'B'], ascending=[1, 0])
```

cc rxin

Author: Davies Liu <[email protected]>

Closes #5544 from davies/compatibility and squashes the following commits:

4944058 [Davies Liu] add docstrings
adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility
bcbbcab [Davies Liu] support ascending as list
8dabdf0 [Davies Liu] improve API compatibility to pandas
1 parent dc48ba9 commit c84d916
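
For reference, a sketch of the pandas-style call forms this commit enables (Spark 1.x-era API); the SparkContext setup and DataFrame contents below are illustrative, not part of the commit:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext(appName="pandas-compat-demo")
sqlCtx = SQLContext(sc)
df = sqlCtx.createDataFrame([Row(A=1, B=2, colA='x', colB='y'),
                             Row(A=3, B=1, colA='x', colB='z')])

df.select(['colA', 'colB']).collect()             # select by a list of names
df.groupby(['colA', 'colB']).count().collect()    # groupby alias + list argument
df.groupby([df.colA, df.colB]).count().collect()  # list of Column objects
df.sort('A', ascending=True).collect()
df.sort(['A', 'B'], ascending=True).collect()
df.sort(['A', 'B'], ascending=[1, 0]).collect()   # per-column sort direction
```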

File tree (3 files changed: +70, −39 lines):

python/pyspark/sql/dataframe.py
python/pyspark/sql/functions.py
python/pyspark/sql/tests.py


python/pyspark/sql/dataframe.py

Lines changed: 66 additions & 30 deletions

```diff
@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None):
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
-    def sort(self, *cols):
+    def sort(self, *cols, **kwargs):
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
 
-        :param cols: list of :class:`Column` to sort by.
+        :param cols: list of :class:`Column` or column names to sort by.
+        :param ascending: sort by ascending order or not, could be bool, int
+            or list of bool, int (default: True).
 
         >>> df.sort(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sort("age", ascending=False).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> df.orderBy(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> from pyspark.sql.functions import *
         >>> df.sort(asc("age")).collect()
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.orderBy(desc("age"), "name").collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        jcols = [_to_java_column(c) for c in cols]
+        ascending = kwargs.get('ascending', True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                jcols = [jc.desc() for jc in jcols]
+        elif isinstance(ascending, list):
+            jcols = [jc if asc else jc.desc()
+                     for asc, jc in zip(ascending, jcols)]
+        else:
+            raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
+
+        jdf = self._jdf.sort(self._jseq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     orderBy = sort
 
+    def _jseq(self, cols, converter=None):
+        """Return a JVM Seq of Columns from a list of Column or names"""
+        return _to_seq(self.sql_ctx._sc, cols, converter)
+
+    def _jcols(self, *cols):
+        """Return a JVM Seq of Columns from a list of Column or column names
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        return self._jseq(cols, _to_java_column)
+
     def describe(self, *cols):
         """Computes statistics for numeric columns.
 
@@ -523,9 +553,7 @@ def describe(self, *cols):
         min     2
         max     5
         """
-        cols = ListConverter().convert(cols,
-                                       self.sql_ctx._sc._gateway._gateway_client)
-        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        jdf = self._jdf.describe(self._jseq(cols))
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
@@ -607,9 +635,7 @@ def select(self, *cols):
         >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.select(self._jcols(*cols))
         return DataFrame(jdf, self.sql_ctx)
 
     def selectExpr(self, *expr):
@@ -620,8 +646,9 @@ def selectExpr(self, *expr):
         >>> df.selectExpr("age * 2", "abs(age)").collect()
         [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
         """
-        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
-        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+        if len(expr) == 1 and isinstance(expr[0], list):
+            expr = expr[0]
+        jdf = self._jdf.selectExpr(self._jseq(expr))
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
@@ -659,6 +686,8 @@ def groupBy(self, *cols):
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
 
+        :func:`groupby` is an alias for :func:`groupBy`.
+
         :param cols: list of columns to group by.
             Each element should be a column name (string) or an expression (:class:`Column`).
 
@@ -668,12 +697,14 @@ def groupBy(self, *cols):
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
         >>> df.groupBy(df.name).avg().collect()
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        >>> df.groupBy(['name', df.age]).count().collect()
+        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self.sql_ctx._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)
 
+    groupby = groupBy
+
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None):
         if thresh is None:
             thresh = len(subset) if how == 'any' else 1
 
-        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-        return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
+        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
 
     def fillna(self, value, subset=None):
         """Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ def fillna(self, value, subset=None):
         elif not isinstance(subset, (list, tuple)):
             raise ValueError("subset should be a list or tuple of column names")
 
-        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-        return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+        return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
     @ignore_unicode_prefix
     def withColumn(self, colName, col):
@@ -862,10 +889,8 @@ def _api(self):
 
 def df_varargs_api(f):
     def _api(self, *args):
-        jargs = ListConverter().convert(args,
-                                        self.sql_ctx._sc._gateway._gateway_client)
         name = f.__name__
-        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
         return DataFrame(jdf, self.sql_ctx)
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -912,9 +937,8 @@ def agg(self, *exprs):
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
-                                            self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+            jdf = self._jdf.agg(exprs[0]._jc,
+                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
     return jcol
 
 
+def _to_seq(sc, cols, converter=None):
+    """
+    Convert a list of Column (or names) into a JVM Seq of Column.
+
+    An optional `converter` could be used to convert items in `cols`
+    into JVM Column objects.
+    """
+    if converter:
+        cols = [converter(c) for c in cols]
+    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
+    return sc._jvm.PythonUtils.toSeq(jcols)
+
+
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
@@ -1177,8 +1214,7 @@ def inSet(self, *cols):
             cols = cols[0]
         cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-        jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
         return Column(jc)
 
     # order
```
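
The heart of the new `sort` behavior is the `ascending` keyword handling. A minimal standalone sketch of that branch, using plain strings in place of JVM Column objects so it runs without Spark (`apply_ascending` is a hypothetical name, not part of the patch):

```python
# Standalone sketch of sort()'s `ascending` handling; strings stand in
# for JVM Column objects, and appending ".desc()" marks a descending column.
def apply_ascending(jcols, ascending=True):
    if isinstance(ascending, (bool, int)):
        # A single bool/int flips every column together.
        if not ascending:
            jcols = [c + ".desc()" for c in jcols]
    elif isinstance(ascending, list):
        # A list gives each column its own direction, zipped pairwise.
        jcols = [c if asc else c + ".desc()"
                 for asc, c in zip(ascending, jcols)]
    else:
        raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
    return jcols

print(apply_ascending(["age", "name"], ascending=[0, 1]))
# ['age.desc()', 'name'] -- matches df.orderBy(["age", "name"], ascending=[0, 1])
```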

python/pyspark/sql/functions.py

Lines changed: 3 additions & 8 deletions

```diff
@@ -23,13 +23,11 @@
 if sys.version < "3":
     from itertools import imap as map
 
-from py4j.java_collections import ListConverter
-
 from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
 __all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@@ -87,8 +85,7 @@ def countDistinct(col, *cols):
     [Row(c=2)]
     """
     sc = SparkContext._active_spark_context
-    jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
-    jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+    jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
     return Column(jc)
 
 
@@ -138,9 +135,7 @@ def __del__(self):
 
     def __call__(self, *cols):
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        sc._gateway._gateway_client)
-        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
         return Column(jc)
```
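
Both call sites now funnel their varargs through `_to_seq`; from the Python side the calls look unchanged. A usage sketch under that assumption (the `[Row(c=2)]` result mirrors the `countDistinct` doctest; the setup code is illustrative):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.sql.functions import countDistinct, udf
from pyspark.sql.types import IntegerType

sc = SparkContext(appName="functions-demo")
sqlCtx = SQLContext(sc)
df = sqlCtx.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])

# countDistinct(col, *cols): extra columns go through _to_seq(sc, cols, _to_java_column)
df.agg(countDistinct(df.age, df.name).alias('c')).collect()  # [Row(c=2)]

# UserDefinedFunction.__call__ routes its column arguments through _to_seq as well
double_age = udf(lambda a: a * 2, IntegerType())
df.select(double_age(df.age).alias('age2')).collect()
```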

python/pyspark/sql/tests.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -282,7 +282,7 @@ def test_apply_schema(self):
             StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
             StructField("list1", ArrayType(ByteType(), False), False),
             StructField("null1", DoubleType(), True)])
-        df = self.sqlCtx.applySchema(rdd, schema)
+        df = self.sqlCtx.createDataFrame(rdd, schema)
         results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
                          x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
         r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
```
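
The test change reflects the migration from the deprecated `applySchema` to `createDataFrame`. A minimal sketch of the same migration, with a hypothetical two-column schema rather than the full one from the test:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sc = SparkContext(appName="schema-demo")
sqlCtx = SQLContext(sc)
rdd = sc.parallelize([("Alice", 2), ("Bob", 5)])

schema = StructType([StructField("name", StringType(), True),
                     StructField("age", IntegerType(), True)])

# old (deprecated): df = sqlCtx.applySchema(rdd, schema)
df = sqlCtx.createDataFrame(rdd, schema)
df.collect()  # [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
```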
