PySpark: Top N Records In Each Group

Previously I blogged about extracting top N records from each group using Hive. This post shows how to do the same in PySpark.

Compared to the earlier Hive version, this approach is much more efficient: it uses combiners (so part of the computation happens on the map side) and keeps at most N records per key at any given time, on both the mapper and the reducer side.
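The "at most N records at a time" idea can be illustrated without Spark. The sketch below is an alternative to the post's approach, not a copy of it: where the post's takeOrderedByKey re-runs heapq.nlargest/nsmallest on a short list, this version keeps a size-N min-heap and uses heapq.heappushpop so that each new record either displaces the current smallest or is dropped. The helper name bounded_top_n is hypothetical.

```python
import heapq


def bounded_top_n(records, n, key):
    """Return the n records with the largest key, holding at most n in memory.

    Hypothetical helper: a min-heap of size n holds the current top n, and
    each new record either displaces the smallest of them or is dropped.
    """
    heap = []  # entries are (key, position, record); position breaks ties
    for pos, rec in enumerate(records):
        entry = (key(rec), pos, rec)
        if len(heap) < n:
            heapq.heappush(heap, entry)
        else:
            # Push the new entry and pop the smallest one in a single step,
            # so the heap never grows beyond n entries.
            heapq.heappushpop(heap, entry)
    # Return records with the largest key first
    return [rec for _, _, rec in sorted(heap, reverse=True)]


class_a = [('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13),
           ('ClassA', 'Student3', 3.87)]
top2 = bounded_top_n(class_a, 2, key=lambda r: r[2])
print(top2)
```

Either way, memory per key stays proportional to N rather than to the number of records in the group.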

import heapq

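def takeOrderedByKey(self, num, sortValue=None, reverse=False):
    """Return an RDD of (key, list) pairs holding the top num values per key,
    ordered by sortValue (smallest first, or largest first if reverse=True)."""

    def init(a):
        # createCombiner: start a per-key list with the first value
        return [a]

    def combine(agg, a):
        # mergeValue: add the new value, then trim back to num records
        agg.append(a)
        return getTopN(agg)

    def merge(a, b):
        # mergeCombiners: join the lists from two partitions, then trim
        agg = a + b
        return getTopN(agg)

    def getTopN(agg):
        if reverse:
            return heapq.nlargest(num, agg, key=sortValue)
        else:
            return heapq.nsmallest(num, agg, key=sortValue)

    return self.combineByKey(init, combine, merge)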

# Create a fake student dataset. The objective is to identify the top 2
# students in each class based on GPA.
data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
    ('ClassB', 'Student1', 2.89), ('ClassB', 'Student2', 3.13), ('ClassB', 'Student3', 3.97),
]

# Add takeOrderedByKey function to RDD class 
from pyspark.rdd import RDD
RDD.takeOrderedByKey = takeOrderedByKey

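# Load the dataset as a pair RDD keyed by class
from pyspark import SparkContext
sc = SparkContext.getOrCreate()  # reuses the pyspark shell's SparkContext if one exists
rdd1 = sc.parallelize(data).map(lambda x: (x[0], x))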

# Extract the top 2 records in each class, ordered by GPA in descending order
for i in rdd1.takeOrderedByKey(2, sortValue=lambda x: x[2], reverse=True).flatMap(lambda x: x[1]).collect():
    print(i)

The output of the above program is as follows; the order in which the classes appear depends on partitioning:

('ClassB', 'Student3', 3.97)
('ClassB', 'Student2', 3.13)
('ClassA', 'Student1', 3.89)
('ClassA', 'Student3', 3.87)

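The key line to understand is the call to self.combineByKey(init, combine, merge). The combineByKey operator groups the dataset by key: init starts a list for the first value of a key within a partition, combine adds further values from the same partition, and merge joins the lists produced by different partitions. After each step, getTopN uses Python's heapq module (nlargest or nsmallest) to trim the list back to N records ordered by GPA. You can find a good explanation of the combineByKey operator on Adam Shinn’s blog.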
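To see how records flow through init, combine and merge without a cluster, the sketch below roughly mimics combineByKey in plain Python: it folds each partition locally, then merges the per-partition lists. The helpers make_top_n_combiners and simulate_combine_by_key are hypothetical and not part of PySpark, and the two "partitions" are hand-made; the combiner logic follows the same top-N idea, with combine adding the new value before trimming.

```python
import heapq


def make_top_n_combiners(num, sort_value=None, reverse=False):
    """Build the three combineByKey callbacks for a per-key top-N."""
    def get_top_n(agg):
        if reverse:
            return heapq.nlargest(num, agg, key=sort_value)
        return heapq.nsmallest(num, agg, key=sort_value)

    def init(a):          # first value seen for a key in a partition
        return [a]

    def combine(agg, a):  # fold another value into the partition-local list
        return get_top_n(agg + [a])

    def merge(a, b):      # combine lists coming from two partitions
        return get_top_n(a + b)

    return init, combine, merge


def simulate_combine_by_key(partitions, init, combine, merge):
    """Hypothetical local stand-in: combine within partitions, then merge."""
    per_partition = []
    for part in partitions:
        local = {}
        for k, v in part:
            local[k] = combine(local[k], v) if k in local else init(v)
        per_partition.append(local)

    result = {}
    for local in per_partition:
        for k, agg in local.items():
            result[k] = merge(result[k], agg) if k in result else agg
    return result


data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
    ('ClassB', 'Student1', 2.89), ('ClassB', 'Student2', 3.13), ('ClassB', 'Student3', 3.97),
]
pairs = [(x[0], x) for x in data]
partitions = [pairs[0::2], pairs[1::2]]  # two fake partitions mixing both classes

init, combine, merge = make_top_n_combiners(2, sort_value=lambda x: x[2], reverse=True)
top = simulate_combine_by_key(partitions, init, combine, merge)
for cls in sorted(top):
    print(cls, top[cls])
```

Because every list is trimmed right after it grows, no intermediate list holds more than N + 1 records (combine) or 2N records (merge).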

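Finally, note that in the call rdd1.takeOrderedByKey(2, sortValue=lambda x: x[2], reverse=True), the x in sortValue refers to the value of the pair RDD created by rdd1 = sc.parallelize(data).map(lambda x: (x[0], x)), i.e. the full (class, student, GPA) record, so x[2] is the GPA.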
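A small plain-Python illustration of that point, assuming only the data layout used in the post: combineByKey passes just the value half of each (key, value) pair to the callbacks, so the key function receives a record tuple such as ('ClassA', 'Student1', 3.89) rather than the pair itself. The name sort_value is illustrative.

```python
import heapq

data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
]

# Same shape as rdd1: (class, full record) pairs
pairs = [(x[0], x) for x in data]

# Only the value part of each pair reaches init/combine/merge,
# so the key function sees the full record tuple.
values = [v for k, v in pairs if k == 'ClassA']


def sort_value(x):
    # x is e.g. ('ClassA', 'Student1', 3.89); x[2] is the GPA
    return x[2]


top2 = heapq.nlargest(2, values, key=sort_value)
print(top2)
```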


About Ritesh Agrawal

I am an applied researcher who enjoys anything related to statistics, large data analysis, data mining, machine learning and data visualization.
This entry was posted in Programming.

3 Responses to PySpark: Top N Records In Each Group

  1. Just curious

    def init(a):
        return [a]

    Won't this part cause the whole operation to work in the driver rather than the executor?

  2. kopi says:

    Why are you calling getTopN twice, in combine and merge?

  3. kopi says:

    I think this should work:

    def combine(agg, a):
        return agg

    def merge(a, b):
        agg = a + b
        return getTopN(agg)
