WORK IN PROGRESS
The scenario
We are trying to join two dataframes:
- df_fact_sales, which is from a fact table
- df_dim_country, which is from a dimension table
We are joining the dataframes on the country_id column.
But, as you can see below, the sales fact table is highly skewed.
When you join the tables, it causes an OOM error, because the executor handling partition 1 gets overloaded with far more rows than the other executors.
Join:
You can solve this problem by using the salting technique.
First, let's understand how the partition number is set.
How Spark decides the partition number in a shuffle join
In a shuffle join, Spark decides the partition number of a row by the formula -
partition_number = hash(join_column) % spark.sql.shuffle.partitions
In this example, let's assume spark.sql.shuffle.partitions=3
So, the formula is -
partition_number = hash(join_column) % 3
So, rows with the same value in the join_column will go into the same partition.
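To make this concrete, here is a plain-Python simulation of the formula. The toy country_id values are made up for illustration, and Python's built-in hash() stands in for the Murmur3 hash Spark uses internally - the routing logic is the same:

```python
# Simulate Spark's hash partitioning: partition_number = hash(join_column) % 3
SHUFFLE_PARTITIONS = 3

# A skewed fact table: country 1 accounts for 8 of the 10 rows.
fact_country_ids = [1] * 8 + [2, 3]

partitions = {}
for country_id in fact_country_ids:
    p = hash(country_id) % SHUFFLE_PARTITIONS
    partitions.setdefault(p, []).append(country_id)

# Every row with a given country_id lands in the same partition,
# so the partition holding country 1 receives 8 of the 10 rows.
for p, rows in sorted(partitions.items()):
    print(p, len(rows))
```

This is exactly the skew problem: the same join key always hashes to the same partition, so one executor inherits the whole hot key.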
Salting
Now, if we can control which row goes into which partition, we can solve this data skew issue.
We cannot change the country_id column. But we can add an artificially created column to the dataframe. Let's call this column salt. Then, instead of joining on just country_id, we will join on (country_id, salt).
In this case, the partition_number is calculated as -
partition_number = hash(country_id, salt) % spark.sql.shuffle.partitions
partition_number = hash(country_id, salt) % 3
Let's decide the salt value to be from 0 to 2. So the salt value can be any one of {0, 1, 2}.
Preparing the fact dataframe
We'll assign the salt randomly to rows in the df_fact_sales dataframe. After salting, the sales (fact) dataframe looks like this -
Notice the new partition numbers, and how even the partitions have now become.
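The fact-side salting step can be sketched in plain Python (toy data, not the author's tables; Python's hash() stands in for Spark's Murmur3):

```python
import random

SHUFFLE_PARTITIONS = 3
random.seed(42)  # fixed seed so the example is reproducible

# Skewed fact rows as (country_id, amount): country 1 dominates.
fact_rows = [(1, amt) for amt in range(8)] + [(2, 100), (3, 200)]

# Assign a random salt in {0, 1, 2} to every fact row.
fact_salted = [(cid, random.randint(0, 2), amt) for cid, amt in fact_rows]

# Partition on the composite key (country_id, salt).
partitions = {}
for cid, salt, amt in fact_salted:
    p = hash((cid, salt)) % SHUFFLE_PARTITIONS
    partitions.setdefault(p, []).append((cid, salt, amt))

# Country 1's rows are no longer pinned to a single partition: with three
# salt values they can spread across up to three partitions.
```

In PySpark the same idea is typically expressed with a random-salt column, but the hash routing it produces is what this sketch shows.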
Preparing the dimension dataframe
Now, it's time to process the country (dimension) dataframe.
We will take the country dataframe and explode it to create every combination of (country_id, salt). An easy way to do this is to create a single-column dataframe of the salt values, and then perform a cross join between df_salt and df_dim_country.
Again, Spark will calculate the partition number with the same formula -
partition_number = hash(country_id, salt) % 3
The country dataframe now looks like this -
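The explosion itself is just a cross product of countries and salt values. A plain-Python sketch (the country names are made up for illustration):

```python
# Dimension rows as (country_id, country_name) - hypothetical values.
dim_country = [(1, "US"), (2, "UK"), (3, "IN")]
salt_values = [0, 1, 2]  # same salt range used on the fact side

# Cross join: one copy of each country row per salt value.
dim_country_exploded = [
    (country_id, salt, name)
    for (country_id, name) in dim_country
    for salt in salt_values
]

# 3 countries x 3 salts = 9 rows. Every (country_id, salt) pair a salted
# fact row can carry now has exactly one matching dimension row.
```

The dimension table is small, so tripling it is cheap - that is why salting explodes the dimension side rather than the fact side.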
Joining the prepared dataframes
Finally, we join the dataframes on the columns (country_id, salt). The join operation will look like this -
Notice how df_dim_country_exploded always has the rows required to join with df_fact_sales_salted. This is because of the dataframe explosion we did.
Notice the almost even distribution of rows. The executor handling partition 1 is no longer overloaded, and the query can execute without an OOM error.
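The end-to-end technique can be verified in the same plain-Python model (toy data and names assumed; a dict lookup stands in for Spark's shuffle join):

```python
import random

random.seed(7)  # fixed seed so the example is reproducible
salt_values = [0, 1, 2]

# Salted fact rows as (country_id, salt, amount); country 1 is skewed.
fact_salted = [(1, random.choice(salt_values), amt) for amt in range(8)]
fact_salted += [(2, random.choice(salt_values), 100),
                (3, random.choice(salt_values), 200)]

# Exploded dimension rows: every (country_id, salt) combination.
dim_country = {1: "US", 2: "UK", 3: "IN"}  # hypothetical names
dim_exploded = [(cid, s, name)
                for cid, name in dim_country.items()
                for s in salt_values]

# Join on the composite key (country_id, salt).
dim_lookup = {(cid, s): name for cid, s, name in dim_exploded}
joined = [(cid, amt, dim_lookup[(cid, s)]) for cid, s, amt in fact_salted]

# Every fact row finds exactly one dimension match, so the result contains
# the same rows a plain join on country_id alone would have produced -
# salting changes the partitioning, not the answer.
```

That last point is the correctness argument for salting: because the dimension side was exploded over the full salt range, adding salt never drops or duplicates a fact row.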