@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None):
485
485
return DataFrame (jdf , self .sql_ctx )
486
486
487
487
@ignore_unicode_prefix
488
- def sort (self , * cols ):
488
+ def sort (self , * cols , ** kwargs ):
489
489
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
490
490
491
- :param cols: list of :class:`Column` to sort by.
491
+ :param cols: list of :class:`Column` or column names to sort by.
492
+ :param ascending: sort by ascending order or not, could be bool, int
493
+ or list of bool, int (default: True).
492
494
493
495
>>> df.sort(df.age.desc()).collect()
494
496
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
497
+ >>> df.sort("age", ascending=False).collect()
498
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
495
499
>>> df.orderBy(df.age.desc()).collect()
496
500
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
497
501
>>> from pyspark.sql.functions import *
498
502
>>> df.sort(asc("age")).collect()
499
503
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
500
504
>>> df.orderBy(desc("age"), "name").collect()
501
505
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
506
+ >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
507
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
502
508
"""
503
509
if not cols :
504
510
raise ValueError ("should sort by at least one column" )
505
- jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
506
- self ._sc ._gateway ._gateway_client )
507
- jdf = self ._jdf .sort (self ._sc ._jvm .PythonUtils .toSeq (jcols ))
511
+ if len (cols ) == 1 and isinstance (cols [0 ], list ):
512
+ cols = cols [0 ]
513
+ jcols = [_to_java_column (c ) for c in cols ]
514
+ ascending = kwargs .get ('ascending' , True )
515
+ if isinstance (ascending , (bool , int )):
516
+ if not ascending :
517
+ jcols = [jc .desc () for jc in jcols ]
518
+ elif isinstance (ascending , list ):
519
+ jcols = [jc if asc else jc .desc ()
520
+ for asc , jc in zip (ascending , jcols )]
521
+ else :
522
+ raise TypeError ("ascending can only be bool or list, but got %s" % type (ascending ))
523
+
524
+ jdf = self ._jdf .sort (self ._jseq (jcols ))
508
525
return DataFrame (jdf , self .sql_ctx )
509
526
510
527
orderBy = sort
511
528
529
+ def _jseq (self , cols , converter = None ):
530
+ """Return a JVM Seq of Columns from a list of Column or names"""
531
+ return _to_seq (self .sql_ctx ._sc , cols , converter )
532
+
533
+ def _jcols (self , * cols ):
534
+ """Return a JVM Seq of Columns from a list of Column or column names
535
+
536
+ If `cols` has only one list in it, cols[0] will be used as the list.
537
+ """
538
+ if len (cols ) == 1 and isinstance (cols [0 ], list ):
539
+ cols = cols [0 ]
540
+ return self ._jseq (cols , _to_java_column )
541
+
512
542
def describe (self , * cols ):
513
543
"""Computes statistics for numeric columns.
514
544
@@ -523,9 +553,7 @@ def describe(self, *cols):
523
553
min 2
524
554
max 5
525
555
"""
526
- cols = ListConverter ().convert (cols ,
527
- self .sql_ctx ._sc ._gateway ._gateway_client )
528
- jdf = self ._jdf .describe (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols ))
556
+ jdf = self ._jdf .describe (self ._jseq (cols ))
529
557
return DataFrame (jdf , self .sql_ctx )
530
558
531
559
@ignore_unicode_prefix
@@ -607,9 +635,7 @@ def select(self, *cols):
607
635
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
608
636
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
609
637
"""
610
- jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
611
- self ._sc ._gateway ._gateway_client )
612
- jdf = self ._jdf .select (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
638
+ jdf = self ._jdf .select (self ._jcols (* cols ))
613
639
return DataFrame (jdf , self .sql_ctx )
614
640
615
641
def selectExpr (self , * expr ):
@@ -620,8 +646,9 @@ def selectExpr(self, *expr):
620
646
>>> df.selectExpr("age * 2", "abs(age)").collect()
621
647
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
622
648
"""
623
- jexpr = ListConverter ().convert (expr , self ._sc ._gateway ._gateway_client )
624
- jdf = self ._jdf .selectExpr (self ._sc ._jvm .PythonUtils .toSeq (jexpr ))
649
+ if len (expr ) == 1 and isinstance (expr [0 ], list ):
650
+ expr = expr [0 ]
651
+ jdf = self ._jdf .selectExpr (self ._jseq (expr ))
625
652
return DataFrame (jdf , self .sql_ctx )
626
653
627
654
@ignore_unicode_prefix
@@ -659,6 +686,8 @@ def groupBy(self, *cols):
659
686
so we can run aggregation on them. See :class:`GroupedData`
660
687
for all the available aggregate functions.
661
688
689
+ :func:`groupby` is an alias for :func:`groupBy`.
690
+
662
691
:param cols: list of columns to group by.
663
692
Each element should be a column name (string) or an expression (:class:`Column`).
664
693
@@ -668,12 +697,14 @@ def groupBy(self, *cols):
668
697
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
669
698
>>> df.groupBy(df.name).avg().collect()
670
699
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
700
+ >>> df.groupBy(['name', df.age]).count().collect()
701
+ [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
671
702
"""
672
- jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
673
- self ._sc ._gateway ._gateway_client )
674
- jdf = self ._jdf .groupBy (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
703
+ jdf = self ._jdf .groupBy (self ._jcols (* cols ))
675
704
return GroupedData (jdf , self .sql_ctx )
676
705
706
+ groupby = groupBy
707
+
677
708
def agg (self , * exprs ):
678
709
""" Aggregate on the entire :class:`DataFrame` without groups
679
710
(shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None):
744
775
if thresh is None :
745
776
thresh = len (subset ) if how == 'any' else 1
746
777
747
- cols = ListConverter ().convert (subset , self .sql_ctx ._sc ._gateway ._gateway_client )
748
- cols = self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols )
749
- return DataFrame (self ._jdf .na ().drop (thresh , cols ), self .sql_ctx )
778
+ return DataFrame (self ._jdf .na ().drop (thresh , self ._jseq (subset )), self .sql_ctx )
750
779
751
780
def fillna (self , value , subset = None ):
752
781
"""Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ def fillna(self, value, subset=None):
799
828
elif not isinstance (subset , (list , tuple )):
800
829
raise ValueError ("subset should be a list or tuple of column names" )
801
830
802
- cols = ListConverter ().convert (subset , self .sql_ctx ._sc ._gateway ._gateway_client )
803
- cols = self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols )
804
- return DataFrame (self ._jdf .na ().fill (value , cols ), self .sql_ctx )
831
+ return DataFrame (self ._jdf .na ().fill (value , self ._jseq (subset )), self .sql_ctx )
805
832
806
833
@ignore_unicode_prefix
807
834
def withColumn (self , colName , col ):
@@ -862,10 +889,8 @@ def _api(self):
862
889
863
890
def df_varargs_api (f ):
864
891
def _api (self , * args ):
865
- jargs = ListConverter ().convert (args ,
866
- self .sql_ctx ._sc ._gateway ._gateway_client )
867
892
name = f .__name__
868
- jdf = getattr (self ._jdf , name )(self .sql_ctx ._sc . _jvm . PythonUtils . toSeq ( jargs ))
893
+ jdf = getattr (self ._jdf , name )(_to_seq ( self .sql_ctx ._sc , args ))
869
894
return DataFrame (jdf , self .sql_ctx )
870
895
_api .__name__ = f .__name__
871
896
_api .__doc__ = f .__doc__
@@ -912,9 +937,8 @@ def agg(self, *exprs):
912
937
else :
913
938
# Columns
914
939
assert all (isinstance (c , Column ) for c in exprs ), "all exprs should be Column"
915
- jcols = ListConverter ().convert ([c ._jc for c in exprs [1 :]],
916
- self .sql_ctx ._sc ._gateway ._gateway_client )
917
- jdf = self ._jdf .agg (exprs [0 ]._jc , self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
940
+ jdf = self ._jdf .agg (exprs [0 ]._jc ,
941
+ _to_seq (self .sql_ctx ._sc , [c ._jc for c in exprs [1 :]]))
918
942
return DataFrame (jdf , self .sql_ctx )
919
943
920
944
@dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
1006
1030
return jcol
1007
1031
1008
1032
1033
+ def _to_seq (sc , cols , converter = None ):
1034
+ """
1035
+ Convert a list of Column (or names) into a JVM Seq of Column.
1036
+
1037
+ An optional `converter` could be used to convert items in `cols`
1038
+ into JVM Column objects.
1039
+ """
1040
+ if converter :
1041
+ cols = [converter (c ) for c in cols ]
1042
+ jcols = ListConverter ().convert (cols , sc ._gateway ._gateway_client )
1043
+ return sc ._jvm .PythonUtils .toSeq (jcols )
1044
+
1045
+
1009
1046
def _unary_op (name , doc = "unary operator" ):
1010
1047
""" Create a method for given unary operator """
1011
1048
def _ (self ):
@@ -1177,8 +1214,7 @@ def inSet(self, *cols):
1177
1214
cols = cols [0 ]
1178
1215
cols = [c ._jc if isinstance (c , Column ) else _create_column_from_literal (c ) for c in cols ]
1179
1216
sc = SparkContext ._active_spark_context
1180
- jcols = ListConverter ().convert (cols , sc ._gateway ._gateway_client )
1181
- jc = getattr (self ._jc , "in" )(sc ._jvm .PythonUtils .toSeq (jcols ))
1217
+ jc = getattr (self ._jc , "in" )(_to_seq (sc , cols ))
1182
1218
return Column (jc )
1183
1219
1184
1220
# order
0 commit comments