Optimize a Spark query to run 4,500× faster

This project started with a deceptively simple question:

Can we identify return trips in New York City taxi data?

A taxi trip b is a return trip for a if:

  1. b picks up within 8 hours after a drops off

  2. b picks up within r meters of a’s dropoff

  3. b drops off within r meters of a’s pickup
    where r ∈ {50, 100, 150, 200} meters
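The three conditions translate directly into a predicate. Below is a minimal sketch in plain Scala; the `Trip` fields, epoch-second timestamps, and the haversine helper are my assumptions for illustration, not the project's actual schema:

```scala
// Hypothetical minimal trip record; field names are assumptions.
case class Trip(pickupLat: Double, pickupLon: Double,
                dropoffLat: Double, dropoffLon: Double,
                pickupTs: Long, dropoffTs: Long) // epoch seconds

val EightHoursSec = 8 * 3600L

// Great-circle distance in meters (haversine formula).
def haversineM(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double = {
  val R = 6371000.0 // mean Earth radius in meters
  val dLat = math.toRadians(lat2 - lat1)
  val dLon = math.toRadians(lon2 - lon1)
  val a = math.pow(math.sin(dLat / 2), 2) +
    math.cos(math.toRadians(lat1)) * math.cos(math.toRadians(lat2)) *
      math.pow(math.sin(dLon / 2), 2)
  2 * R * math.asin(math.sqrt(a))
}

// The three return-trip conditions, verbatim from the definition above.
def isReturnTrip(a: Trip, b: Trip, r: Double): Boolean =
  b.pickupTs >= a.dropoffTs &&
    b.pickupTs <= a.dropoffTs + EightHoursSec &&
    haversineM(a.dropoffLat, a.dropoffLon, b.pickupLat, b.pickupLon) < r &&
    haversineM(a.pickupLat, a.pickupLon, b.dropoffLat, b.dropoffLon) < r
```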

The baseline implementation looked straightforward.
In practice, the first version took 250 seconds — more than four minutes.


The 250-Second Problem

My first, naive approach looked like this:

val joined = trips.alias("a").crossJoin(trips.alias("b"))

If you have ~10 million trips, that means:

10M × 10M = 100 trillion comparisons

Even after adding time filters, distance calculations, and reasonable Spark tuning, it still ran around 250 seconds. It was clear Spark wasn’t the problem — my thinking was.
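To see why, here is what the cross join logically materializes, sketched in plain Scala (the name `candidatePairs` is mine, not the pipeline's):

```scala
// What crossJoin produces, logically: every (a, b) pair, with all
// filtering applied only afterwards. With n trips this is n^2 pairs
// before a single filter runs.
def candidatePairs[A](trips: Seq[A]): Seq[(A, A)] =
  for { a <- trips; b <- trips } yield (a, b)
```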

I shifted the question from:

“How do I make Spark faster?”
to
“Why am I comparing trips that will never match?”

That shift unlocked the solution.


Rethinking Geography

NYC covers around 1,200 km² and fits inside a bounding box of roughly 50 km × 50 km.
Our matching radius is 50–200 meters.

A tiny dot inside a huge canvas.

So I mapped NYC onto a grid of r × r meter cells, which gives a useful guarantee:

Any two points within r meters of each other must land in the same cell or in one of its adjacent cells.

Instead of joining across the entire city, a trip in cell (i, j) only needs to compare with trips in these 9 neighboring cells:

+-----------+-----------+-----------+
| (i-1,j-1) | (i-1, j)  | (i-1,j+1) |
+-----------+-----------+-----------+
| (i, j-1)  | (i, j)    | (i, j+1)  |
+-----------+-----------+-----------+
| (i+1,j-1) | (i+1, j)  | (i+1,j+1) |
+-----------+-----------+-----------+

Just 9 boxes, instead of the entire city.

This alone removed ~99% of unnecessary comparisons.
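A sketch of the cell mapping, assuming flat-earth scale factors near NYC's latitude (~40.7° N); the constants and names are illustrative, not from the actual pipeline:

```scala
// Approximate meters per degree near 40.7° N; illustrative constants,
// not exact geodesy.
val MetersPerDegLat = 111320.0
val MetersPerDegLon = 111320.0 * math.cos(math.toRadians(40.7))

// Map a point to integer grid indices for a cell size of r meters.
def cellOf(lat: Double, lon: Double, r: Double): (Int, Int) =
  (math.floor(lat * MetersPerDegLat / r).toInt,
   math.floor(lon * MetersPerDegLon / r).toInt)

// The 3x3 neighborhood around (i, j): the only cells that can contain
// a point within r meters of anything inside cell (i, j).
def neighbors(i: Int, j: Int): Seq[(Int, Int)] =
  for { di <- -1 to 1; dj <- -1 to 1 } yield (i + di, j + dj)
```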


Time Works the Same Way

Return trips must happen within 8 hours.
So I discretized time into 8-hour buckets:

time_idx = floor(timestamp / 8 hours)

Return-trip candidates satisfy:

time_idx_B ∈ [time_idx_A, time_idx_A + 1]

Again, only local neighbors — not the whole month.
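The bucketing itself is a one-liner; a sketch assuming epoch-second timestamps (names are mine):

```scala
val EightHoursSec = 8 * 3600L

// Bucket an epoch-second timestamp into 8-hour windows.
// Integer division floors correctly for non-negative timestamps.
def timeIdx(ts: Long): Long = ts / EightHoursSec

// A return trip starts within 8 hours of A's dropoff, so it can only
// fall in A's bucket or the next one.
def candidateBuckets(dropoffTs: Long): Seq[Long] =
  Seq(timeIdx(dropoffTs), timeIdx(dropoffTs) + 1)
```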


Reshaping the Data Into Something Spark Loves

I duplicated each trip in tripAB once for each candidate key:

  • 9 spatial neighbor cells

  • × 2 time buckets
    = 18 rows per trip

Usually duplication is bad, but here it made the join trivial.

The join key became three integers:

(lat_idx, lon_idx, time_idx)

Spark can hash-partition this instantly — no trig functions, no shuffle explosions, no Cartesian nightmares.

The expensive haversine distance and time checks were applied after the join, on a tiny set of candidates.
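Putting the two ideas together: the "a" side of the join can emit all 18 candidate keys per trip while the "b" side emits just one, so every true match lands on a shared integer key. A plain-Scala sketch, with assumed scale factors (~40.7° N) and epoch-second timestamps:

```scala
// Illustrative constants, not from the actual pipeline.
val MetersPerDegLat = 111320.0
val MetersPerDegLon = 111320.0 * math.cos(math.toRadians(40.7))
val EightHoursSec = 8 * 3600L

// The exact (lat_idx, lon_idx, time_idx) key of a point; this is what
// the "b" side emits once per trip.
def key(lat: Double, lon: Double, ts: Long, r: Double): (Int, Int, Long) =
  (math.floor(lat * MetersPerDegLat / r).toInt,
   math.floor(lon * MetersPerDegLon / r).toInt,
   ts / EightHoursSec)

// The "a" side: 9 neighbor cells x 2 time buckets = 18 keys per trip,
// covering every cell/bucket a matching "b" could occupy.
def expandedKeys(lat: Double, lon: Double, ts: Long, r: Double): Seq[(Int, Int, Long)] = {
  val (i, j, t) = key(lat, lon, ts, r)
  for { di <- -1 to 1; dj <- -1 to 1; dt <- 0L to 1L }
    yield (i + di, j + dj, t + dt)
}
```

After the equi-join on this triple, the surviving candidate set is tiny, and the exact haversine and 8-hour checks run only on those rows.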


The Result

The optimized pipeline didn’t run in four minutes.
It didn’t run in one minute.

It ran in:

👉 55 milliseconds

From 250 seconds → 0.055 seconds.

A speedup of over 4,500×.

Not from tuning Spark, but from giving Spark far less work.


Lessons Learned

  • Don’t optimize the join. Reduce the search space.
    Most pairs should never be compared; remove them before the join exists.

  • Geography and time are not fields — they are structure.
    Use them to build natural partitions the data actually lives in.

  • Spark isn’t slow. Wrong data shapes are slow.
    Make your data small, local, and integer-joinable.

  • Trigonometry is expensive. Run it last, not first.

  • Sometimes duplication speeds things up.
    If it turns a hard join into a cheap one, do it.

  • The best optimization is a better question.
    “Why am I comparing these rows at all?”