Dealing with data skew using salting - work in progress




The scenario

We are trying to join two dataframes.

  1. df_fact_sales, which is from a fact table.
  2. df_dim_country, which is from a dimension table.

We are joining the dataframes on the country_id column.

But the sales fact table is highly skewed: the vast majority of its rows belong to a single country_id.

When you join the tables, it causes an OOM (out-of-memory) error, because the executor handling partition 1 gets overloaded with far more rows than the other executors.


You can solve this problem by using the salting technique.

First, let's understand how the partition number is set.

How Spark decides the partition number in a shuffle join

During the shuffle, Spark decides the partition number of a row by the formula -

partition_number = hash(join_column) % spark.sql.shuffle.partitions

In this example, let's assume spark.sql.shuffle.partitions=3

So, the formula is -

partition_number = hash(join_column) % 3

So, rows with the same value in the join_column will go into the same partition.
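As a plain-Python sketch of the formula above (Spark actually uses a Murmur3-based hash, and the country IDs and row counts here are made up for illustration), you can see how a hot key overloads one partition:

```python
NUM_PARTITIONS = 3  # stand-in for spark.sql.shuffle.partitions

def partition_number(join_value):
    # partition_number = hash(join_column) % spark.sql.shuffle.partitions
    # (Python's built-in hash() is only a stand-in for Spark's hash.)
    return hash(join_value) % NUM_PARTITIONS

# A skewed fact table: country_id 1 dominates.
fact_country_ids = [1] * 9 + [2, 3]

row_counts = {}
for country_id in fact_country_ids:
    p = partition_number(country_id)
    row_counts[p] = row_counts.get(p, 0) + 1

# Every row with country_id = 1 hashes to the same partition,
# so a single partition holds 9 of the 11 rows.
```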


Now, if we can control which row goes into which partition, we can solve this data skew issue.

We cannot change the country_id column. But we can add an artificially created column to the dataframe. Let's call this column salt. Then, instead of joining on just country_id, we will join on (country_id, salt).

In this case, the partition_number is calculated as -

partition_number = hash(country_id, salt) % spark.sql.shuffle.partitions

partition_number = hash(country_id, salt) % 3

Let's decide the salt value to be from 0 to 2. So the salt value can be any one of {0, 1, 2}.
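Continuing the plain-Python sketch (again using Python's hash() only as a stand-in for Spark's), adding the salt to the hash input lets the same country_id map to up to three different partitions:

```python
NUM_PARTITIONS = 3
SALT_VALUES = range(3)  # salt is one of {0, 1, 2}

def salted_partition_number(country_id, salt):
    # partition_number = hash(country_id, salt) % 3
    return hash((country_id, salt)) % NUM_PARTITIONS

# The hot key country_id = 1 can now be spread across up to three
# partitions, one per salt value, instead of always landing in one.
partitions_for_hot_key = {salted_partition_number(1, s) for s in SALT_VALUES}
```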

Preparing the fact dataframe

We'll assign the salt randomly to rows in the df_fact_sales dataframe. After salting, the sales (fact) dataframe looks like this -

Notice the new partition numbers, and how even the partitions have become.

Preparing the dimension dataframe

Now, it's time to process the country (dimension) dataframe.

We will take the country dataframe, and explode it to create every combination of (country_id, salt). An easy way to do this is by creating a single column dataframe of the salt values, and then performing a cross-join between df_salt and df_dim_country.

Here too, Spark will calculate the partition number with the same formula -

partition_number = hash(country_id, salt) % 3

The country dataframe now looks like this -

Joining the prepared dataframes

Finally, we join the dataframes on the columns (country_id, salt). The join operation will look like this -

Notice how the df_dim_country_exploded always has the rows required to join with df_fact_sales_salted. This is because of the dataframe explosion we did.

Notice the almost even distribution of rows. The executor handling partition 1 is no longer overloaded, and the query can actually execute without any OOM error.

