Select Top n Records For Each Group In Python (Pandas)

Say that you have a dataframe in Pandas and you are interested in finding the top n records for each group. Depending on your needs, "top n" can be defined by a numeric column in your dataframe, or simply by the number of occurrences (rows) in each group.

For example, suppose (god forbid) that I own a retail store with four branches (A, B, C, D), and that for each day I would like to get the three branches that sold the most items. So I take the day as the group, count the number of items each branch sold on that day, and then pick the three branches with the highest number of sales. That’s our problem; without further ado, let’s see the code for our example.

import pandas as pd

sales = {
    'branch_name': ['A', 'A', 'B', 'B', 'C', 'C', 'C', 'C', 'D',
                    'A', 'A', 'B', 'B', 'C', 'D', 'D', 'D', 'D'],
    'date': ['01/01/2020', '01/01/2020', '01/01/2020', '01/01/2020', '01/01/2020', '01/01/2020',
             '01/01/2020', '01/01/2020', '01/01/2020', '02/01/2020', '02/01/2020', '02/01/2020',
             '02/01/2020', '02/01/2020', '02/01/2020', '02/01/2020', '02/01/2020', '02/01/2020'],
    'item_no': ['I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9',
                'I10', 'I11', 'I12', 'I13', 'I14', 'I15', 'I16', 'I17', 'I18'],
}
sales = pd.DataFrame(sales)
# Dates are day/month/year, so '02/01/2020' parses as 2 January 2020
sales['date'] = pd.to_datetime(sales['date'], format='%d/%m/%Y')

Sweet, let’s see what our dataframe looks like:

   branch_name       date item_no
0            A 2020-01-01      I1
1            A 2020-01-01      I2
2            B 2020-01-01      I3
3            B 2020-01-01      I4
4            C 2020-01-01      I5
5            C 2020-01-01      I6
6            C 2020-01-01      I7
7            C 2020-01-01      I8
8            D 2020-01-01      I9
9            A 2020-01-02     I10
10           A 2020-01-02     I11
11           B 2020-01-02     I12
12           B 2020-01-02     I13
13           C 2020-01-02     I14
14           D 2020-01-02     I15
15           D 2020-01-02     I16
16           D 2020-01-02     I17
17           D 2020-01-02     I18

Great, from the dataframe above you can see that on 2020-01-01 most sales came from branches A, B and C, whereas on 2020-01-02 most sales came from branches A, B and D. Let’s write the code that returns what we just worked out by eye.
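To double-check the by-eye reading before writing the solution, a cross-tabulation lays out how many rows each branch has on each day. This is a minimal sketch; it rebuilds the same 18 rows more compactly than the dictionary above, and `pd.crosstab` is a standard pandas function.

```python
import pandas as pd

# Same 18 rows as the sales dataframe above, built more compactly
sales = pd.DataFrame({
    'branch_name': list('AABBCCCCD') + list('AABBCDDDD'),
    'date': pd.to_datetime(['01/01/2020'] * 9 + ['02/01/2020'] * 9, format='%d/%m/%Y'),
    'item_no': [f'I{i}' for i in range(1, 19)],
})

# One row per day, one column per branch; each cell counts that branch's rows on that day
counts = pd.crosstab(sales['date'], sales['branch_name'])
print(counts)
```

Reading across each row, C, A and B have the most rows on 2020-01-01, and D, A and B on 2020-01-02.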

Solution

The solution consists of the following steps:
1- We need to count the number of items sold for each day and each branch.

# nunique counts the distinct item numbers for each (date, branch) pair
sales_agg = sales.groupby(['date', 'branch_name']).agg({'item_no': 'nunique'}).reset_index()
print(sales_agg)
        date branch_name  item_no
0 2020-01-01           A        2
1 2020-01-01           B        2
2 2020-01-01           C        4
3 2020-01-01           D        1
4 2020-01-02           A        2
5 2020-01-02           B        2
6 2020-01-02           C        1
7 2020-01-02           D        4
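Because every row in this data has its own `item_no`, counting distinct item numbers gives the same numbers as counting rows. The sketch below shows the row-count variant with `size()` next to the `nunique` version. If the same `item_no` could appear more than once for a branch on a day, the two would differ: `nunique` counts distinct values, while `size` counts rows.

```python
import pandas as pd

# Same 18 rows as the sales dataframe above, built more compactly
sales = pd.DataFrame({
    'branch_name': list('AABBCCCCD') + list('AABBCDDDD'),
    'date': pd.to_datetime(['01/01/2020'] * 9 + ['02/01/2020'] * 9, format='%d/%m/%Y'),
    'item_no': [f'I{i}' for i in range(1, 19)],
})

keys = ['date', 'branch_name']
# Distinct item numbers per (date, branch), as in step 1
by_nunique = sales.groupby(keys).agg({'item_no': 'nunique'}).reset_index()
# Plain row count per (date, branch), stored under the same column name
by_size = sales.groupby(keys).size().reset_index(name='item_no')

print(by_size)
```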

2- For each day, we need to sort the branches by the number of items sold, in descending order.

# Within each date, sort the branches by item count, highest first
sales_sorted = (sales_agg.groupby(['date'])
                .apply(lambda x: x.sort_values(['item_no'], ascending=False))
                .reset_index(drop=True))
print(sales_sorted)
        date branch_name  item_no
0 2020-01-01           C        4
1 2020-01-01           A        2
2 2020-01-01           B        2
3 2020-01-01           D        1
4 2020-01-02           D        4
5 2020-01-02           A        2
6 2020-01-02           B        2
7 2020-01-02           C        1
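Sorting within each date is the same as sorting the whole table by date and then by count. So step 2 can also be done with a single `sort_values` call and no `groupby(...).apply`, which also avoids the DeprecationWarning that recent pandas versions may raise when `apply` receives the grouping column. A minimal sketch of that alternative:

```python
import pandas as pd

# Same 18 rows as the sales dataframe above, built more compactly
sales = pd.DataFrame({
    'branch_name': list('AABBCCCCD') + list('AABBCDDDD'),
    'date': pd.to_datetime(['01/01/2020'] * 9 + ['02/01/2020'] * 9, format='%d/%m/%Y'),
    'item_no': [f'I{i}' for i in range(1, 19)],
})
sales_agg = sales.groupby(['date', 'branch_name']).agg({'item_no': 'nunique'}).reset_index()

# Sort by date ascending, then by item count descending within each date
sales_sorted = (
    sales_agg.sort_values(['date', 'item_no'], ascending=[True, False])
    .reset_index(drop=True)
)
print(sales_sorted)
```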

3- Now, for each date, we need to pick the top n records (3 in our case).

sales_sorted.groupby(['date']).head(3)
        date branch_name  item_no
0 2020-01-01           C        4
1 2020-01-01           A        2
2 2020-01-01           B        2
4 2020-01-02           D        4
5 2020-01-02           A        2
6 2020-01-02           B        2
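One thing to keep in mind: `head(3)` always keeps exactly three rows per date, so if the third and fourth branches were tied, one of them would be dropped based purely on sort order. If ties at the cutoff should be kept, one option (a swapped-in technique, not the approach above) is to rank branches within each date with `rank(method='min')` and keep every rank up to n. This can return more than n rows for a date. A minimal sketch:

```python
import pandas as pd

# Same 18 rows as the sales dataframe above, built more compactly
sales = pd.DataFrame({
    'branch_name': list('AABBCCCCD') + list('AABBCDDDD'),
    'date': pd.to_datetime(['01/01/2020'] * 9 + ['02/01/2020'] * 9, format='%d/%m/%Y'),
    'item_no': [f'I{i}' for i in range(1, 19)],
})
sales_agg = sales.groupby(['date', 'branch_name']).agg({'item_no': 'nunique'}).reset_index()

n = 3
# Rank within each date by item count, highest count = rank 1;
# method='min' gives tied branches the same (lowest) rank
ranks = sales_agg.groupby('date')['item_no'].rank(method='min', ascending=False)
# Keep every branch ranked n or better, including all ties at the cutoff
top_n = sales_agg[ranks <= n].sort_values(['date', 'item_no'], ascending=[True, False])
print(top_n)
```

With this data there is no tie at the cutoff, so the result contains the same six rows as `head(3)`.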

Great, we are done. This returns the top 3 branches by total number of items sold for each day. Hope this helps you solve your own problem.
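The introduction also mentions defining top n by a numeric column rather than by counts. In that case the same sort-then-`head` pattern applies directly to the raw rows, with no aggregation step. Below is a sketch with a hypothetical `amount` column and a hypothetical helper `top_n_per_group`; neither is part of the article's data or of pandas.

```python
import pandas as pd

# Hypothetical example data: 'amount' is an illustrative numeric column
orders = pd.DataFrame({
    'branch_name': ['A', 'A', 'A', 'B', 'B', 'B'],
    'amount': [10, 50, 30, 5, 25, 15],
})


def top_n_per_group(df, group_col, value_col, n):
    """Hypothetical helper: rows with the n largest value_col values in each group."""
    return (
        df.sort_values(value_col, ascending=False)  # largest values first overall
        .groupby(group_col)
        .head(n)  # first n rows of each group, i.e. its n largest values
        .sort_values([group_col, value_col], ascending=[True, False])  # tidy display order
    )


print(top_n_per_group(orders, 'branch_name', 'amount', 2))
```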