PySpark - Exercise: Reviewing AI-Generated Code

Reviewing AI-Generated Code

By now you can write Spark, and so can an AI assistant. The skill that pays the rent is reading what it produces and knowing when it will hurt you. This exercise hands you a pipeline that an assistant might give you: it runs, and on a sample the answer is correct. Your job is to make it correct and fast on the full table, and to explain every change.

The setup

Run this first. It builds an orders table and a customers table so the exercise is self-contained.

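from pyspark.sql import functions as F

# `spark` is the SparkSession the notebook already provides.
# One million orders: a random customer, one of five country codes, an amount up to 100.
orders = (spark.range(1_000_000).withColumnRenamed("id", "order_id")
          .withColumn("customer_id", (F.rand() * 50_000).cast("int"))
          # element_at is 1-based, so the index below runs from 1 to 5.
          .withColumn("country", F.element_at(
              F.array(F.lit("nl"), F.lit("de"), F.lit("fr"), F.lit("us"), F.lit("uk")),
              (F.rand() * 5 + 1).cast("int")))
          .withColumn("amount", F.round(F.rand() * 100, 2)))

# Fifty thousand customers, each assigned to segment A or B at random.
customers = (spark.range(50_000).withColumnRenamed("id", "customer_id")
             .withColumn("segment", F.when(F.rand() > 0.5, "A").otherwise("B")))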

A million rows is small enough to run quickly and large enough to feel the difference. Picture it standing in for the fifty-million-row table you would meet in production.

The pipeline an assistant gave you

You asked for the total order amount per country. This is what came back. It runs, and the totals are right.

from pyspark.sql.types import StringType

# normalise the country code
to_upper = F.udf(lambda s: s.upper() if s else s, StringType())
orders_norm = orders.withColumn("country", to_upper("country"))

# attach customer details
joined = orders_norm.join(customers, "customer_id")

# total the amounts per country
rows = joined.select("country", "amount").collect()
totals = {}
for r in rows:
    totals[r["country"]] = totals.get(r["country"], 0) + r["amount"]

print(totals)

Your task

  1. Run it. It works, so the problem is not correctness. It is cost.
  2. Open the query profile for the run, from the cell output or the Query History. Find the step that pulls data back to the driver, and the stage that moves the most data.
  3. Rewrite the pipeline so it returns the same totals, faster, and keeps the work on the cluster.
  4. Write two or three sentences for each change, saying what was wrong and why your version is better.
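Step 3 asks for the same totals, and a strict equality check can mislead you here. Spark adds the amounts in a different order than the Python loop, so the sums may differ in the last digits. A minimal sketch of a tolerant comparison follows. The helper name totals_match and the sample numbers are hypothetical, not part of the exercise.

```python
import math


def totals_match(expected, actual, rel_tol=1e-9):
    """Return True when both mappings have the same keys and near-equal sums."""
    if expected.keys() != actual.keys():
        return False
    return all(math.isclose(expected[k], actual[k], rel_tol=rel_tol)
               for k in expected)


# Hypothetical numbers standing in for the two runs: the same amounts,
# added in a different order.
loop_totals = {"NL": 0.1 + 0.2 + 0.3, "DE": 10.0}
spark_totals = {"NL": 0.3 + 0.2 + 0.1, "DE": 10.0}

print(loop_totals == spark_totals)              # strict comparison
print(totals_match(loop_totals, spark_totals))  # tolerant comparison
```

In the notebook, you would collect the handful of result rows from your rewritten pipeline into a dict keyed by country and pass it in next to the dict the original loop produced.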

What to look for

Three things in that code will not scale. Try to find them before you read on.

  • A Python UDF does what a built-in already does. Every value has to be handed out to Python and back, which is slow, and it blocks Spark’s optimiser from reasoning about that column. F.upper runs inside the engine.
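  • The join never tells Spark that customers is small. Spark broadcasts a table on its own only when its size estimate falls under spark.sql.autoBroadcastJoinThreshold (10 MB by default). When the estimate is missing or too high, Spark falls back to a join that shuffles the large orders table across the network to line up keys. A broadcast join sends the small table to every executor instead, and skips the shuffle.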
  • The code collect()s the whole joined table to the driver and totals it in a Python loop. That pulls a million rows off the cluster into one machine’s memory, which is slow at best and an out-of-memory crash at scale. Grouping is Spark’s job.

One way to fix it

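totals = (orders
          # Built-in upper() runs inside the engine; no rows are handed to Python.
          .withColumn("country", F.upper("country"))
          # Hint that customers is small enough to copy to every executor.
          .join(F.broadcast(customers), "customer_id")
          # Aggregate on the cluster; only one row per country comes back.
          .groupBy("country")
          .agg(F.sum("amount").alias("total_amount")))
totals.show()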

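Same totals, up to floating-point rounding in the last digits, because Spark adds the amounts in a different order than the Python loop. The uppercasing runs in the engine, the small table is broadcast instead of shuffled, and the totalling stays distributed. Nothing returns to the driver except the handful of rows in the result.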

What to hand in

  • The fixed notebook.
  • The before and after run times.
  • A short diagnosis of each of the three traps.
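For the run times, remember that Spark is lazy: building a DataFrame costs almost nothing, and the work only happens when an action such as collect() or show() runs. A minimal sketch of a timing helper follows. The name time_action is hypothetical, and the workload is a plain-Python stand-in so the snippet runs anywhere.

```python
import time


def time_action(action):
    """Run a zero-argument callable and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = action()
    return result, time.perf_counter() - start


# Stand-in workload; in the notebook, pass a lambda that triggers a Spark
# action instead.
result, seconds = time_action(lambda: sum(range(1_000_000)))
print(f"{seconds:.3f} s")
```

In the notebook, you would wrap each version's action in a lambda, for example the collect() in the original pipeline and the action that materialises your rewritten result. Timing each version more than once helps, since a first run often includes warm-up cost.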

You are graded on the diagnosis, not on whether the code runs. Anyone can run it. Knowing why a stage is slow is the job, and it is exactly what the Spark certification asks about under broadcasting, shuffling, and tuning. For the techniques behind these fixes, see Tuning Essentials.

Key questions you can now answer

  • Why is a Python UDF usually slower than the equivalent built-in function?
  • When does a broadcast join beat a regular join, and what goes wrong without it?
  • Why is collect() on a large DataFrame dangerous, and what do you use instead?
  • How do you find the most expensive stage of a job from the query profile?
  • Why is “it returns the right answer” not enough to call Spark code finished?