1 Standard practices for efficient coding 🧠
1.1 Import PySpark modules with default aliases
Always start with the following imports: the F
prefix for all PySpark functions, the T
prefix for all PySpark types, and W
for Window. Adding the # noqa
comment at the end of the line stops linters (like flake8) from flagging the line as an error. Importing these aliases lets you use the types and functions without importing each item separately, resulting in a speedy workflow. It also eliminates the oopsie where you apply the built-in Python sum instead of the PySpark one ☹.
from pyspark.sql import DataFrame as df, functions as F, types as T, Window as W  # noqa
F.sum()
T.IntegerType()
W.partitionBy()
1.2 Type annotations and docstrings in functions
Type hints and docstrings are a pretty good idea; they bring numerous benefits:
- improved readability -> makes it easier to understand the code
- enhanced IDE support -> autocompletion, type checking, etc.
- better documentation -> document while you are coding and store the solution with the problem
- static type checking -> tools like mypy can help you catch errors early
from pyspark.sql import DataFrame
# bad
def calculate_daily_scoops(df, flavor_col, date_col):
    return df.groupBy(date_col, flavor_col).agg(F.sum("scoops_sold").alias("total_scoops"))
# good
def calculate_daily_scoops(df: DataFrame, flavor_col: str, date_col: str) -> DataFrame:
    """Calculate total scoops sold per day and flavor.

    Groups the input DataFrame by date and flavor, aggregating the total number
    of scoops sold for each unique combination.

    Parameters
    ----------
    df : DataFrame
        Input DataFrame containing sales data.
    flavor_col : str
        Name of the column containing ice cream flavor information.
    date_col : str
        Name of the column containing date information.

    Returns
    -------
    DataFrame
        DataFrame with total scoops sold, grouped by date and flavor.
        Contains columns for date, flavor, and total scoops.

    Examples
    --------
    >>> import pyspark.sql.functions as F
    >>> sales_df = spark.createDataFrame(...)
    >>> daily_scoops = calculate_daily_scoops(sales_df, 'flavor', 'sale_date')
    >>> daily_scoops.show()
    """
    return df.groupBy(date_col, flavor_col).agg(
        F.sum("scoops_sold").alias("total_scoops")
    )
1.3 Formatting
black, just use black.
2 Implicit column selection
# bad
df = df.select(F.lower(F.col('colA')), F.upper(F.col('colB')))

# good
df = df.select(F.lower(df1.colA), F.upper(df2.colB))

# better - since Spark 3.0
df = df.select(F.lower('colA'), F.upper('colB'))
In most cases, the third option is best. Spark 3.0 expanded the scenarios where this approach works. However, when using strings for column selection isn't feasible, fall back to the more explicit F.col() form.
2.1 When to deviate
In some contexts, you might need to access columns from multiple dataframes with overlapping names. For example, in matching expressions like df.join(df2, on=(df.key == df2.key), how='left')
. In such cases, it’s acceptable to reference columns directly by their dataframe. Please see Joins for more details.
3 Logical operations
Logical operations, which often reside inside .filter()
or F.when()
, need to be readable. Keep logic expressions inside the same code block to three (3) expressions at most. If they grow longer, it is often a sign that the code can and should be extracted and/or simplified. Extracting complex logical operations into variables makes the code easier to read and reason about, which in turn also reduces bugs.
# bad
F.when((F.col('prod_status') == 'Delivered') | (((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('currentRegistration') != '') | ((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('originalOperator') != '') | (F.col('currentOperator') != '')))))), 'In Service')
The code above can be simplified in different ways. To start, focus on grouping the logic steps into named variables. PySpark requires that expressions are wrapped with parentheses -> (). This, mixed with the actual parentheses used to group logical operations, can hurt readability. For example, the code above has a redundant condition -> (F.datediff('deliveryDate_actual', 'current_date') < 0)
that is very hard to spot.
# better
has_operator = ((F.col('originalOperator') != '') | (F.col('currentOperator') != ''))
delivery_date_passed = (F.datediff('deliveryDate_actual', 'current_date') < 0)
has_registration = (F.col('currentRegistration').rlike('.+'))
is_delivered = (F.col('prod_status') == 'Delivered')

F.when(is_delivered | (delivery_date_passed & (has_registration | has_operator)), 'In Service')
The above example drops the redundant expression and is easier to read. We can improve it further by reducing the number of operations and grouping the ones that have a business meaning together.
# good
has_registration = F.col("currentRegistration").rlike(".+")
has_operator = (F.col("originalOperator") != "") | (F.col("currentOperator") != "")
is_active = has_registration | has_operator
is_delivered = F.col("prod_status") == "Delivered"
delivery_date_passed = F.datediff("deliveryDate_actual", "current_date") < 0

F.when(is_delivered | (delivery_date_passed & is_active), "In Service")
The F.when
expression is now readable and the desired behavior is clear to anyone reviewing this code. Surprise: that person is going to be you in 6 months. The reader only needs to visit the individual expressions if they suspect there is an error. It also makes each chunk of logic easy to reason about and test. You do have unit tests, right? If you want, these expressions can be extracted into their own functions.
4 Use select to specify a schema contract
Doing a select at the beginning or at the end of a PySpark transform specifies the contract, with both the reader and the code, about the expected dataframe schema for inputs and outputs. Keep select statements as simple as possible. Apply only one function from pyspark.sql.functions
per column, plus an optional .alias()
to give it a meaningful name. Expressions involving more than one dataframe, or conditional operations like .when()
, should not be used in a select.
# bad
aircraft = aircraft.select(
    'aircraft_id',
    'aircraft_msn',
    F.col('aircraft_registration').alias('registration'),
    'aircraft_type',
    F.avg('staleness').alias('avg_staleness'),
    F.col('number_of_economy_seats').cast('long'),
    F.avg('flight_hours').alias('avg_flight_hours'),
    'operator_code',
    F.col('number_of_business_seats').cast('long'),
)
Unless order matters to you, try to cluster together operations of the same type, reducing the cognitive load on the reader of the code.
# good
aircraft = aircraft.select(
    "aircraft_id",
    "aircraft_msn",
    "aircraft_type",
    "operator_code",
    F.col("aircraft_registration").alias("registration"),
    F.col("number_of_economy_seats").cast("long"),
    F.col("number_of_business_seats").cast("long"),
    F.avg("staleness").alias("avg_staleness"),
    F.avg("flight_hours").alias("avg_flight_hours"),
)
4.1 Modifying columns in select
The select()
statement redefines the schema of a dataframe, so it naturally supports the inclusion or exclusion of columns, old and new, as well as the redefinition of pre-existing ones. By centralising all such operations in a single statement, it becomes much easier to identify the final schema, which makes code more readable. It also makes code more concise.
Instead of calling withColumnRenamed()
, use alias()
:
# bad
df.select('key', 'comments').withColumnRenamed('comments', 'num_comments')

# good
df.select("key", F.col("comments").alias("num_comments"))
Instead of using withColumn()
to redefine a type, use cast()
in the select:
# bad
df.select('comments').withColumn('comments', F.col('comments').cast('double'))

# good
df.select(F.col("comments").cast("double"))
4.2 Inclusive selection above exclusive
Only select columns that are needed, and avoid using drop()
, as over time it may return a different dataframe than expected, and it will be your job to fix it.
Finally, instead of adding new columns via the select statement, using .withColumn()
is recommended when adding a single column. When adding or manipulating tens or hundreds of columns, use a single .select()
for performance reasons.
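A small, hypothetical illustration of that recommendation (the column names and the tax rate are made up for the example):

# Adding a single derived column: .withColumn() reads well.
df = df.withColumn("price_with_tax", F.col("price") * 1.19)

# Reshaping many columns at once: a single .select() keeps the schema
# contract in one place instead of stacking many withColumn calls.
df = df.select(
    "key",
    "flavor",
    F.col("quantity").cast("long").alias("quantity"),
    (F.col("price") * 1.19).alias("price_with_tax"),
)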
5 Empty columns
If you need to add an empty column to satisfy a schema, always use F.lit(None)
for populating that column. Never use an empty string or some other string signalling an empty value (such as -
, "", or NA
; the latter can be interpreted as 'North America', for example).
Beyond being technically correct (the best kind of correct), one practical reason for using F.lit(None)
is preserving the ability to use utilities like isNull
, instead of having to verify empty strings, nulls, 'NA'
, etc. Your future self will thank you.
# bad
df = df.withColumn('foo', F.lit(''))

# also bad
df = df.withColumn('foo', F.lit('NA'))

# good
df = df.withColumn("foo", F.lit(None))
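One caveat worth adding (standard Spark behaviour, not part of the original example): a bare F.lit(None) produces a column of NullType, so if the target schema expects a concrete type, cast the literal explicitly:

# good - empty column that still satisfies a string-typed schema
df = df.withColumn("foo", F.lit(None).cast(T.StringType()))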
6 Using comments
While comments can provide useful insight into code, it is often more valuable to refactor the code to improve its readability, and to only use comments to explain the why or provide more context. Most, if not all, people looking at the code will be able to understand what the code is doing, but not why it is doing it.
# bad
# Cast the timestamp columns
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))
In the example above, we can see that those columns are getting cast to Timestamp. The comment doesn’t add value. Moreover, a more verbose comment might still be unhelpful if it only provides information that already exists in the code. For example:
# still bad
# Go through each column, divide by 1000 because millis and cast to timestamp
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))
Instead of leaving comments that only describe the logic you wrote, aim to leave comments that give context, that explain the why of decisions you made when writing the code. This is particularly important for PySpark, since the reader can understand your code, but often doesn’t have context on the data that feeds into your PySpark transform. Small pieces of logic might have involved hours of digging👷 through data to understand the correct behavior, in which case comments explaining the rationale are especially valuable.
# good
# The consumer of this dataset expects a timestamp instead of a date, and we need
# to adjust the time by 1000 because the original datasource is storing these as millis
# even though the documentation says it's actually a date.
cols = ["start_date", "delivery_date"]
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))
7 UDFs (user defined functions)
It is highly recommended to avoid UDFs in all situations, as they are dramatically less performant than native PySpark. In most situations, logic that seems to necessitate a UDF can be refactored to use only native PySpark functions. However, there are some situations where a UDF is unavoidable.
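As a hedged sketch of that refactoring idea (the flavor clean-up shown is a made-up example, not from this guide), logic that looks like it needs a UDF can often be expressed with built-in column functions instead:

from pyspark.sql import functions as F, types as T

# bad - a UDF for simple string clean-up
@F.udf(returnType=T.StringType())
def normalize_flavor(flavor):
    return flavor.strip().lower() if flavor is not None else None

df = df.withColumn("flavor_normalized", normalize_flavor("flavor"))

# good - native functions do the same job and let Spark optimise the plan
df = df.withColumn("flavor_normalized", F.lower(F.trim("flavor")))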
8 Joins
Be careful with joins! If you perform a left join, and the right side has multiple matches for a key, that row will be duplicated as many times as there are matches. This is called a “join explosion” and can dramatically bloat the output of your transforms job. Always double check your assumptions to see that the key you are joining on is unique, unless you are expecting the multiplication.
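To make the explosion concrete, here is a toy sketch (the data is invented for illustration): the right side has two rows for the same key, so a left join doubles the matching row.

flights = spark.createDataFrame([("F100", "A320")], ["flight_code", "aircraft_type"])
parking = spark.createDataFrame(
    [("F100", "bay_1"), ("F100", "bay_2")], ["flight_code", "bay"]
)

flights.join(parking, on="flight_code", how="left").show()
# +-----------+-------------+-----+
# |flight_code|aircraft_type|  bay|
# +-----------+-------------+-----+
# |       F100|         A320|bay_1|
# |       F100|         A320|bay_2|
# +-----------+-------------+-----+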
Bad joins are the source of many tricky-to-debug issues. Always pass the join type by name, even if you are using the default value (inner
). You did know that inner
is the default, right?
# bad
flights = flights.join(aircraft, 'aircraft_id')

# also bad
flights = flights.join(aircraft, 'aircraft_id', 'inner')

# good
flights = flights.join(other=aircraft, on="aircraft_id", how="inner")
Avoid right
joins. If you are about to use a right
join, switch the order of your dataframes and use a left
join instead. It is more intuitive, since the dataframe you are doing the operation on is the one that you are centering your join around. That sentence doesn't sound like it makes sense, but it does.
# bad
flights = aircraft.join(flights, 'aircraft_id', how='right')

# good
flights = flights.join(other=aircraft, on="aircraft_id", how="left")
8.1 Column collisions when joining
Avoid renaming all columns to avoid collisions. Instead, give an alias to the whole dataframe, and use that alias to select which columns you want in the end.
# bad
columns = ['start_time', 'end_time', 'idle_time', 'total_time']
for col in columns:
    flights = flights.withColumnRenamed(col, 'flights_' + col)
    parking = parking.withColumnRenamed(col, 'parking_' + col)

flights = flights.join(parking, on='flight_code', how='left')

flights = flights.select(
    F.col('flights_start_time').alias('flight_start_time'),
    F.col('flights_end_time').alias('flight_end_time'),
    F.col('parking_total_time').alias('client_parking_total_time')
)

# good
flights = flights.alias("flights")
parking = parking.alias("parking")

flights = flights.join(other=parking, on="flight_code", how="left")

flights = flights.select(
    F.col("flights.start_time").alias("flight_start_time"),
    F.col("flights.end_time").alias("flight_end_time"),
    F.col("parking.total_time").alias("client_parking_total_time"),
)
In such cases, keep in mind:
- It is a better idea to only select the columns that are needed before joining.
- In case you do need both, it might be best to rename one of them prior to joining, signaling the difference between the two columns, as most likely the underlying data generating process is different.
- You should always resolve ambiguous columns before outputting a dataset. After the transform is finished, you can no longer distinguish between them.
8.2 .dropDuplicates() and .distinct() to “clean” joins
Don’t think of .dropDuplicates()
or .distinct()
as a quick fix for data duplication after a join. If unexpected duplicate rows are in your dataframe, there is always an underlying reason for why those duplicate rows appear. Adding .dropDuplicates()
only masks this problem and adds unnecessary CPU cycles.
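If you suspect a join is exploding, a quick diagnostic (a sketch, not part of the original text) is to check whether the right-hand side's join key is actually unique before joining, rather than de-duplicating afterwards:

duplicate_keys = aircraft.groupBy("aircraft_id").count().filter(F.col("count") > 1)

# If this is non-zero, fix the upstream data or the join logic instead of
# sprinkling .dropDuplicates() on the result.
duplicate_keys.count()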
9 Window Functions
Window functions are incredibly useful when performing data transformations, especially when you need to perform calculations that require a specific order or grouping of data. They allow you to perform operations across a set of table rows that are somehow related to the current row. This is particularly powerful for operations like running totals, moving averages, and rank calculations.
Always explicitly define three things when working with window functions:
- partitionBy
: the partitions or groups over which the window function will be applied.
- orderBy
: the ordering of rows within the partition.
- rowsBetween
or rangeBetween
: the scope in rows or ranges considered when applying a function over the ordered partition.
By specifying these three components, you ensure that your windows are properly defined.
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [("a", 1), ("a", 2), ("a", 3), ("a", 4)],
    ["key", "num"],
)
# bad
w1 = W.partitionBy('key')
w2 = W.partitionBy('key').orderBy('num')

df.select('key', F.sum('num').over(w1).alias('sum')).show()
# +---+---+
# |key|sum|
# +---+---+
# |  a| 10|
# |  a| 10|
# |  a| 10|
# |  a| 10|
# +---+---+

df.select('key', F.sum('num').over(w2).alias('sum')).show()
# +---+---+
# |key|sum|
# +---+---+
# |  a|  1|
# |  a|  3|
# |  a|  6|
# |  a| 10|
# +---+---+

df.select('key', F.first('num').over(w2).alias('first')).show()
# +---+-----+
# |key|first|
# +---+-----+
# |  a|    1|
# |  a|    1|
# |  a|    1|
# |  a|    1|
# +---+-----+

df.select('key', F.last('num').over(w2).alias('last')).show()
# +---+----+
# |key|last|
# +---+----+
# |  a|   1|
# |  a|   2|
# |  a|   3|
# |  a|   4|
# +---+----+
It is much safer to always specify an explicit window with the three components:
# good
w3 = (
    W.partitionBy("key")
    .orderBy("num")
    .rowsBetween(
        start=W.unboundedPreceding,
        end=0,  # <- zero means the current row, -1 is the row before the current
    )
)
w4 = (
    W.partitionBy("key")
    .orderBy("num")
    .rowsBetween(
        start=W.unboundedPreceding,
        end=W.unboundedFollowing,
    )
)

df.select("key", F.sum("num").over(w3).alias("sum")).collect()
# [Row(key="a", sum=1),
#  Row(key="a", sum=3),
#  Row(key="a", sum=6),
#  Row(key="a", sum=10)]

df.select("key", F.sum("num").over(w4).alias("sum")).collect()
# [Row(key="a", sum=10),
#  Row(key="a", sum=10),
#  Row(key="a", sum=10),
#  Row(key="a", sum=10)]

df.select("key", F.first("num").over(w4).alias("first")).collect()
# [Row(key="a", first=1),
#  Row(key="a", first=1),
#  Row(key="a", first=1),
#  Row(key="a", first=1)]

df.select("key", F.last("num").over(w4).alias("last")).collect()
# [Row(key="a", last=4),
#  Row(key="a", last=4),
#  Row(key="a", last=4),
#  Row(key="a", last=4)]
9.1 Dealing with nulls
While nulls are ignored for aggregate functions (like F.sum()
and F.max()
), they will impact the result of analytic functions (like F.first()/F.last()
and F.lead()/F.lag()
).
df_nulls = spark.createDataFrame(
    [("a", None), ("a", 1), ("a", 2), ("a", None)],
    ["key", "num"],
)

df_nulls.select("key", F.first("num").over(w4).alias("first")).collect()
# [Row(key="a", first=None),
#  Row(key="a", first=None),
#  Row(key="a", first=None),
#  Row(key="a", first=None)]

df_nulls.select("key", F.last("num").over(w4).alias("last")).collect()
# [Row(key="a", last=None),
#  Row(key="a", last=None),
#  Row(key="a", last=None),
#  Row(key="a", last=None)]
Best to avoid this problem by enabling the ignorenulls
flag:
df_nulls.select(
    "key", F.first("num", ignorenulls=True).over(w4).alias("first")
).collect()
# [Row(key="a", first=1),
#  Row(key="a", first=1),
#  Row(key="a", first=1),
#  Row(key="a", first=1)]

df_nulls.select("key", F.last("num", ignorenulls=True).over(w4).alias("last")).collect()
# [Row(key="a", last=2),
#  Row(key="a", last=2),
#  Row(key="a", last=2),
#  Row(key="a", last=2)]
Also be mindful of explicit ordering of nulls to make sure the expected results are obtained:
w5 = (
    W.partitionBy("key")
    .orderBy(F.asc_nulls_first("num"))
    .rowsBetween(W.currentRow, W.unboundedFollowing)
)
w6 = (
    W.partitionBy("key")
    .orderBy(F.asc_nulls_last("num"))
    .rowsBetween(W.currentRow, W.unboundedFollowing)
)

df_nulls.select("key", F.lead("num").over(w5).alias("lead")).collect()
# [Row(key="a", lead=None),
#  Row(key="a", lead=None),
#  Row(key="a", lead=1),
#  Row(key="a", lead=2)]

df_nulls.select("key", F.lead("num").over(w6).alias("lead")).collect()
# [Row(key="a", lead=1),
#  Row(key="a", lead=2),
#  Row(key="a", lead=None),
#  Row(key="a", lead=None)]
9.2 Empty partitionBy()
Spark window functions can be applied over all rows, using a global frame. This is accomplished by specifying zero columns in the partition by expression (i.e. W.partitionBy()
).
Code like this should be avoided, however, as it forces Spark to combine all data into a single partition, which can be extremely harmful for performance.
Prefer to use aggregations whenever possible:
# bad
w = W.partitionBy()
df = df.select(F.sum('num').over(w).alias('sum'))

# good
df = df.agg(F.sum("num").alias("sum"))
10 Micro structure
This is a style guide with an opinion, so here are some recommendations on how to structure your scripts and on the best way to apply transformations to your data.
- Start by loading the specific data that you want, combining
select
and where
; repeat for all data sources.
- Select a base table (this should be the largest table) and left join all other tables to it.
- Perform aggregations or window transformations.
- Apply post-processing transformations.
If you take a functional approach, you can use the transform()
method to chain multiple operations together. Keep in mind that the function supplied to transform()
should accept and return a dataframe. This structure will force you to write functions that are ready for unit testing.
10.1 Making transformation functions
This can be achieved in three ways:
10.1.1 Wrapper function
def transform_wrapper(df):
    return f(df, specific_arg1, specific_arg2)

df = df.transform(transform_wrapper)
10.1.2 Lambda function
df = df.transform(lambda df: f(df, specific_arg1, specific_arg2))
10.1.3 Partial function
from functools import partial

f_par = partial(f, arg1=value1, arg2=value2)
df = df.transform(f_par)
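To back up the unit-testing claim from earlier, here is a minimal sketch; the function, column names, and test below are hypothetical and not part of the guide:

from pyspark.sql import DataFrame, SparkSession, functions as F

def add_price_with_tax(df: DataFrame, tax_rate: float) -> DataFrame:
    # A transform-ready function: DataFrame in, DataFrame out.
    return df.withColumn("price_with_tax", F.col("price") * (1 + tax_rate))

def test_add_price_with_tax():
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    input_df = spark.createDataFrame([("vanilla", 2.0)], ["flavor", "price"])

    result = input_df.transform(lambda d: add_price_with_tax(d, tax_rate=0.5))

    assert result.collect()[0]["price_with_tax"] == 3.0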
Chaining expressions is a contentious topic, however, since this is an opinionated guide, we are opting to recommend some limits on the usage of chaining. See the conclusion of this section for a discussion of the rationale behind this recommendation.
Avoid chaining expressions into multi-line expressions with different types, particularly if they have different behaviours or contexts. For example, mixing column creation or joining with selecting and filtering.
# bad
df = (
    df
    .select('a', 'b', 'c', 'key')
    .filter(F.col('a') == 'truthiness')
    .withColumn('boverc', F.col('b') / F.col('c'))
    .join(df2, 'key', how='inner')
    .join(df3, 'key', how='left')
    .drop('c')
)
# better (separating into steps)

# first: we select and trim down the data that we need
# second: we create the columns that we need to have
# third: joining with other dataframes

df = df.select("a", "b", "c", "key").filter(F.col("a") == "truthiness")
df = df.withColumn("boverc", F.col("b") / F.col("c"))
df = df.join(df2, "key", how="inner").join(df3, "key", how="left").drop("c")
Having each group of expressions isolated into its own logical code block improves legibility and makes it easier to find relevant logic. For example, a reader of the code below will probably jump to where they see dataframes being assigned df = df...
.
# bad
df = (
    df
    .select('foo', 'bar', 'foobar', 'abc')
    .filter(F.col('abc') == 123)
    .join(another_table, 'some_field')
)

# better
df = df.select("foo", "bar", "foobar", "abc").filter(F.col("abc") == 123)
df = df.join(other=another_table, on="some_field", how="inner")
There are legitimate reasons to chain expressions together. These commonly represent atomic logic steps, and are acceptable. Apply a rule with a maximum number of chained expressions in the same block to keep the code readable. We recommend chains of no longer than 3 statements.
If you find you are making longer chains, or having trouble because of the size of your variables, consider extracting the logic into a separate function:
# bad
customers_with_shipping_address = (
    customers_with_shipping_address
    .select('a', 'b', 'c', 'key')
    .filter(F.col('a') == 'truthiness')
    .withColumn('boverc', F.col('b') / F.col('c'))
    .join(df2, 'key', how='inner')
)

# also bad
customers_with_shipping_address = customers_with_shipping_address.select('a', 'b', 'c', 'key')
customers_with_shipping_address = customers_with_shipping_address.filter(F.col('a') == 'truthiness')
customers_with_shipping_address = customers_with_shipping_address.withColumn('boverc', F.col('b') / F.col('c'))
customers_with_shipping_address = customers_with_shipping_address.join(df2, 'key', how='inner')

# better
def join_customers_with_shipping_address(customers, df_to_join):
    customers = customers.select("a", "b", "c", "key").filter(
        F.col("a") == "truthiness"
    )
    customers = customers.withColumn("boverc", F.col("b") / F.col("c"))
    return customers.join(other=df_to_join, on="key", how="inner")
Chains of more than 3 statements are prime candidates to factor into separate, well-named functions, since they are already encapsulated, isolated blocks of logic.
The rationale for why we’ve set these limits on chaining:
- Differentiation between PySpark code and SQL code. Chaining is something that goes against most, if not all, other Python styling. You don’t chain in Python, you assign.
- Discourage the creation of large single code blocks. These would often make more sense extracted as a named function.
- It doesn’t need to be all or nothing, but a maximum of five lines of chaining balances practicality with legibility.
- If you are using an IDE, it makes it easier to use automatic extractions or do code movements.
- Large chains are hard to read and maintain, particularly if chains are nested.
# good
flavor_analysis = (
    ice_cream_sales
    .transform(clean_sales_data)
    .transform(add_price_category)
    .transform(calculate_flavor_metrics)
)
11 Avoid: Procedural approach with intermediate DataFrames
clean_sales = clean_sales_data(ice_cream_sales)
categorized_sales = add_price_category(clean_sales)
flavor_analysis = calculate_flavor_metrics(categorized_sales)
12 Other Considerations and Recommendations
- Be wary of functions that grow too large. As a general rule, a file should not be over 250 lines, and a function should not be over 70 lines.
- Try to keep your code in logical blocks. For example, if you have multiple lines referencing the same things, try to keep them together. Separating them reduces context and readability.
- Test your code! If you can run the local tests, do so and make sure that your new code is covered by the tests. If you can’t run the local tests, build the datasets on your branch and manually verify that the data looks as expected.
- Avoid .otherwise(value)
as a general fallback. If you are mapping a list of keys to a list of values and a number of unknown keys appear, using otherwise
will mask all of these into one value.
- Do not keep commented-out code checked into the repository. This applies to single lines of code, functions, classes or modules. Rely on git and its capabilities of branching or looking at history instead.
- When encountering a large single transformation composed of integrating multiple different source tables, split it into the natural sub-steps and extract the logic to functions. This allows for easier higher level readability and allows for code re-usability and consistency between transforms.
- Try to be as explicit and descriptive as possible when naming functions or variables. Strive to capture what the function is actually doing as opposed to naming it based on the objects used inside of it.
- Think twice about introducing new import aliases, unless there is a good reason to do so. Some of the established ones are types
and functions
from PySpark: from pyspark.sql import types as T, functions as F
.
- Avoid using literal strings or integers in filtering conditions, new values of columns, etc. Instead, to capture their meaning, extract them into variables, constants, dicts or classes as suitable. This makes the code more readable and enforces consistency across the repository, as sketched below.
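A small, hypothetical sketch of that last point (the constant name and value are made up for illustration):

# bad - a magic string buried in the transform
df = df.filter(F.col("prod_status") == "Delivered")

# better - the meaning is captured once, in a named constant
STATUS_DELIVERED = "Delivered"

df = df.filter(F.col("prod_status") == STATUS_DELIVERED)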