Rule U005

Avoid loops inside a UDF body — use pyspark.sql.functions.transform instead

Severity

🔴 HIGH — Major performance impact.

PySpark version

Compatible with PySpark 3.1 and later.

Information

UDFs already pay the highest possible per-row cost: every row is serialised from the JVM to the Python interpreter and back. Adding a loop (for-statement or comprehension) inside the UDF body compounds that cost:

  • The Python interpreter iterates over array elements one by one, with no vectorisation and no Catalyst optimisation
  • The JVM↔Python serialisation overhead is paid once per outer row, but the loop runs entirely in slow Python — compare this to a native array function that runs on the JVM or in native code
  • List/set/dict comprehensions and generator expressions are all loops in disguise and carry the same penalty

pyspark.sql.functions.transform applies a lambda to each array element using Spark's native execution engine — the array never leaves the JVM, no UDF boundary is crossed, and Catalyst can optimise the expression.

Reference: pyspark.sql.functions.transform

Best practices

  • Replace for-loops that build a new list with transform(col, lambda x: ...)
  • Replace list comprehensions [f(x) for x in col] with transform(col, lambda x: f(x))
  • For filtering elements, use filter (the array function, not DataFrame filter)
  • For reducing to a scalar, use aggregate

Rule of thumb: If you find yourself looping inside a UDF, there is almost always a native Spark array function that does the same thing faster.

Example

Bad:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, StringType

@udf(returnType=ArrayType(StringType()))
def upper_all(items):
    return [x.upper() for x in items]   # list comprehension inside UDF

@udf(returnType=ArrayType(IntegerType()))
def double_all(items):
    result = []
    for x in items:                     # for-loop inside UDF
        result.append(x * 2)
    return result

Good:

from pyspark.sql.functions import transform, upper, col

# Both run entirely on the JVM — no Python round-trip per element
df.withColumn("items", transform(col("items"), lambda x: upper(x)))
df.withColumn("items", transform(col("items"), lambda x: x * 2))