What is PySpark?

PySpark is the Python API for Apache Spark, allowing you to write Spark applications using Python. It provides a simple way to parallelize your data processing across a cluster while writing familiar Python code.

PySpark combines Python's ease of use with Spark's distributed computing power, making it one of the most widely used ways to work with Spark.

Getting Started

# Install PySpark
pip install pyspark

# Optional extras: pandas + PyArrow, needed for pandas UDFs and toPandas().
# Quote the brackets so shells such as zsh don't treat them as a glob pattern.
pip install "pyspark[sql]"

# Create a SparkSession -- the entry point for the DataFrame and SQL APIs
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("MyApp")
    .master("local[*]")  # run locally with one worker thread per logical core
    .config("spark.sql.shuffle.partitions", "8")  # fewer shuffle partitions suit small local data
    .getOrCreate()
)

# Which Spark version is running?
print(spark.version)

# The lower-level SparkContext (RDD APIs) hangs off the session
sc = spark.sparkContext

Creating DataFrames

# From a Python list of tuples plus a list of column names
data = [
    ("Alice", 28, "Engineering"),
    ("Bob", 35, "Marketing"),
    ("Charlie", 42, "Sales"),
]
columns = ["name", "age", "department"]
df = spark.createDataFrame(data, columns)
df.show()
# +-------+---+-----------+
# |   name|age| department|
# +-------+---+-----------+
# |  Alice| 28|Engineering|
# |    Bob| 35|  Marketing|
# |Charlie| 42|      Sales|
# +-------+---+-----------+

# From a CSV file: first line holds the column names, Spark infers the types
df_csv = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("data/employees.csv")
)

# From a JSON file
df_json = spark.read.json("data/events.json")

# From Parquet -- a columnar format, usually the recommended choice
df_parquet = spark.read.parquet("data/sales.parquet")

# With an explicit schema: column types and nullability are pinned up front
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("name", StringType(), nullable=True),
    StructField("age", IntegerType(), nullable=True),
    StructField("department", StringType(), nullable=True),
])
df = spark.createDataFrame(data, schema)

DataFrame Operations

from pyspark.sql.functions import col, lit, when, count, sum, avg, max, min
# NOTE: this import shadows Python's built-in sum/max/min. That is convenient in
# a notebook, but in larger programs prefer `from pyspark.sql import functions as F`
# and write F.sum(...), F.max(...), etc.

# Select columns (by name, or as Column expressions)
df.select("name", "age").show()
df.select(col("name"), col("age") + 1).show()

# Filter rows (Column expression or SQL string; where() is an alias of filter())
df.filter(col("age") > 30).show()
df.where("age > 30 AND department = 'Sales'").show()

# Add new columns
df = df.withColumn("senior", when(col("age") >= 40, True).otherwise(False))
# Assumes df has a numeric `salary` column -- the sample data above does not.
df = df.withColumn("bonus", col("salary") * 0.1)
df = df.withColumn("country", lit("USA"))  # lit() turns a constant into a Column

# Rename columns
df = df.withColumnRenamed("name", "employee_name")

# Drop columns (a no-op if the column does not exist)
df = df.drop("temp_column")

# Sort/Order
df.orderBy("age").show()  # Ascending
df.orderBy(col("age").desc()).show()  # Descending

# Distinct values
df.select("department").distinct().show()

# Handle nulls -- each call returns a NEW DataFrame, so keep the result
df_no_nulls = df.na.drop()  # drop rows with any null
df_age_known = df.na.drop(subset=["age"])  # drop rows where age is null
df_salary_filled = df.na.fill(0, subset=["salary"])  # replace null salaries with 0

Aggregations

from pyspark.sql.functions import count, sum, avg, max, min, countDistinct

# Basic aggregations over the whole DataFrame (produces a single row)
df.agg(
    count("*").alias("total_rows"),
    countDistinct("department").alias("distinct_departments"),
    avg("age").alias("avg_age"),
    max("salary").alias("max_salary"),
).show()

# Group by: one result row per department
df.groupBy("department").agg(
    count("*").alias("employee_count"),
    avg("salary").alias("avg_salary"),
    sum("salary").alias("total_salary"),
).show()

# Multiple grouping columns
df.groupBy("department", "job_title").agg(
    count("*").alias("count"),
).orderBy("department").show()

# "Having": filter on an aggregated column after grouping
(
    df.groupBy("department")
    .agg(count("*").alias("count"))
    .filter(col("count") > 10)
    .show()
)

# Pivot tables: one output column per quarter. Listing the pivot values
# explicitly spares Spark an extra pass to discover the distinct quarters.
(
    df.groupBy("year")
    .pivot("quarter", ["Q1", "Q2", "Q3", "Q4"])
    .agg(sum("revenue"))
    .show()
)

Joins

# Sample DataFrames
employees = spark.createDataFrame([
    (1, "Alice", 101),
    (2, "Bob", 102),
    (3, "Charlie", 103)
], ["emp_id", "name", "dept_id"])

departments = spark.createDataFrame([
    (101, "Engineering"),
    (102, "Marketing"),
    (104, "HR")
], ["dept_id", "dept_name"])

# Inner join (default): only dept_ids present on both sides (101, 102)
employees.join(departments, "dept_id").show()

# Left outer join: every employee; dept_name is null for Charlie (103)
employees.join(departments, "dept_id", "left").show()

# Right outer join: every department; HR (104) has no employees
employees.join(departments, "dept_id", "right").show()

# Full outer join: all rows from both sides
employees.join(departments, "dept_id", "outer").show()

# Cross join (Cartesian product): 3 x 3 = 9 rows
employees.crossJoin(departments).show()

# Join on multiple columns (df1/df2 stand for any two DataFrames sharing col1 and col2)
df1.join(df2, ["col1", "col2"], "inner").show()

# Join with different column names: pass an explicit join condition
depts_by_id = departments.withColumnRenamed("dept_id", "id")
employees.join(
    depts_by_id,
    employees.dept_id == depts_by_id.id,
    "inner",
).show()

# A join condition on same-named keys keeps BOTH columns; drop the duplicate
employees.join(
    departments,
    employees.dept_id == departments.dept_id,
    "inner",
).drop(departments.dept_id).show()

# Self join: needs a column (manager_id) that points back at emp_id
from pyspark.sql.functions import col

staff = spark.createDataFrame([
    (1, "Alice", None),
    (2, "Bob", 1),
    (3, "Charlie", 1),
], "emp_id int, name string, manager_id int")

emp = staff.alias("e")
mgr = staff.alias("m")
emp.join(mgr, col("e.manager_id") == col("m.emp_id")).select(
    col("e.name").alias("employee"),
    col("m.name").alias("manager"),
).show()

Window Functions

from pyspark.sql.window import Window
from pyspark.sql.functions import avg, col, row_number, rank, dense_rank, lag, lead, sum

# Define window spec: one window per department, highest salary first
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())

# Row number within partition: 1, 2, 3, ... (ties still get distinct numbers)
df = df.withColumn("row_num", row_number().over(window_spec))

# Rank (ties share a rank, then a gap: 1, 1, 3)
df = df.withColumn("rank", rank().over(window_spec))

# Dense rank (ties share a rank, no gap: 1, 1, 2)
df = df.withColumn("dense_rank", dense_rank().over(window_spec))

# Previous/next row values in window order. Because window_spec sorts salary
# descending, prev_salary is the next-higher salary and next_salary the
# next-lower one; both are null at the edges of each department.
df = df.withColumn("prev_salary", lag("salary", 1).over(window_spec))
df = df.withColumn("next_salary", lead("salary", 1).over(window_spec))

# Running totals. The explicit ROWS frame adds one row at a time; with only
# orderBy(), Spark's default frame is RANGE-based, so rows sharing the same
# date would all receive the same combined total.
running_window = (
    Window.partitionBy("department")
    .orderBy("date")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

df = df.withColumn("running_total", sum("amount").over(running_window))

# Moving average over the current row and the two before it (last 3 rows)
moving_window = (
    Window.partitionBy("product")
    .orderBy("date")
    .rowsBetween(-2, 0)
)

df = df.withColumn("moving_avg", avg("sales").over(moving_window))

Spark SQL

# Register DataFrame as a temp view (scoped to this SparkSession)
df.createOrReplaceTempView("employees")

# Run SQL queries; spark.sql returns an ordinary DataFrame
result = spark.sql("""
    SELECT
        department,
        COUNT(*) as employee_count,
        AVG(salary) as avg_salary
    FROM employees
    WHERE age > 25
    GROUP BY department
    HAVING COUNT(*) > 5
    ORDER BY avg_salary DESC
""")
result.show()

# Global temp view: lives in the system database `global_temp` and is shared
# by all SparkSessions of the same Spark application. It is dropped when the
# application stops and is never visible to other applications.
df.createOrReplaceGlobalTempView("global_employees")
spark.sql("SELECT * FROM global_temp.global_employees").show()

# Mix SQL with DataFrame API
(
    spark.sql("SELECT * FROM employees WHERE age > 30")
    .groupBy("department")
    .agg(count("*").alias("count"))
    .show()
)

User Defined Functions (UDFs)

from pyspark.sql.functions import udf, pandas_udf
from pyspark.sql.types import StringType, IntegerType
import pandas as pd

# Standard UDF: a plain Python function that Spark calls once per row
def categorize_age(age):
    """Bucket an age as "Young" (< 30), "Middle" (30 <= age < 50) or "Senior" (>= 50).

    Spark hands SQL NULL to a Python UDF as ``None``. Return ``None`` for it
    instead of failing with ``TypeError`` on ``None < 30``.
    """
    if age is None:
        return None
    if age < 30:
        return "Young"
    if age < 50:
        return "Middle"
    return "Senior"

# Wrap the Python function as a UDF for the DataFrame API
categorize_udf = udf(categorize_age, StringType())
df = df.withColumn("age_category", categorize_udf(col("age")))

# Register UDF for SQL (assumes the `employees` view has name and age columns)
spark.udf.register("categorize_age", categorize_age, StringType())
spark.sql("SELECT name, categorize_age(age) FROM employees").show()

# Pandas UDF: Spark passes whole batches as pandas Series, avoiding per-row
# Python call overhead. Keep the body vectorized -- Series.apply would loop
# in Python again and give most of that advantage back.
@pandas_udf(StringType())
def categorize_age_pandas(ages: pd.Series) -> pd.Series:
    """Bucket ages as "Young" (< 30), "Middle" (30 <= age < 50) or "Senior" (>= 50).

    Null ages arrive as NaN; every comparison with NaN is False, so those rows
    keep the initial None (SQL NULL) instead of being labelled "Senior".
    """
    labels = pd.Series(None, index=ages.index, dtype=object)
    labels[ages < 30] = "Young"
    labels[(ages >= 30) & (ages < 50)] = "Middle"
    labels[ages >= 50] = "Senior"
    return labels

df = df.withColumn("age_category", categorize_age_pandas(col("age")))

Reading and Writing Data

import os

# CSV -- the output path becomes a directory of part files. Put every option on
# one writer: a second write to the same path would replace the first output.
df.write.mode("overwrite").option("header", "true").csv("output/data.csv")

# Parquet (recommended for analytics)
df.write.parquet("output/data.parquet")
df.write.partitionBy("year", "month").parquet("output/partitioned")  # one subdirectory per year/month value

# JSON
df.write.json("output/data.json")

# JDBC (databases) -- the database's JDBC driver must be on Spark's classpath.
# Read credentials from the environment instead of hard-coding them.
# With the default write mode the save fails if the table already exists;
# add .mode("append") or .mode("overwrite") as appropriate.
(
    df.write
    .format("jdbc")
    .option("url", "jdbc:postgresql://localhost:5432/mydb")
    .option("dbtable", "employees")
    .option("user", os.environ["DB_USER"])
    .option("password", os.environ["DB_PASSWORD"])
    .save()
)

# Delta Lake -- requires the delta-spark package and a Delta-enabled SparkSession
df.write.format("delta").save("output/delta_table")

# Write modes
# overwrite - Replace existing data
# append    - Add to existing data
# ignore    - Skip if exists
# error     - Throw error if exists (default; also spelled "errorifexists")

df.write.mode("overwrite").parquet("output/data")

Performance Optimization

# 1. Cache DataFrames that several actions reuse
df.cache()  # for DataFrames, same as df.persist() with its default storage level
df.count()  # caching is lazy -- an action is what actually fills the cache
df.unpersist()  # release the cached data once it is no longer needed

# 2. Broadcast small tables for joins
from pyspark.sql.functions import broadcast

small_df = spark.read.parquet("dim_product.parquet")
large_df = spark.read.parquet("fact_sales.parquet")

# Broadcast join: a full copy of small_df is shipped to every executor, so
# large_df can be joined without being shuffled on product_id
result = large_df.join(broadcast(small_df), "product_id")

# 3. Repartition for parallelism
df = df.repartition(100)  # full shuffle into 100 partitions (can increase the count)
df = df.repartition("department")  # hash-partition rows by column value
df = df.coalesce(10)  # merge down to 10 partitions without a full shuffle

# 4. Use appropriate data types
from pyspark.sql.types import IntegerType
df = df.withColumn("id", col("id").cast(IntegerType()))

# 5. Filter early (predicate pushdown) and keep only the columns you need.
# Transformations return a new DataFrame, so keep the result.
filtered = (
    df.filter(col("date") > "2024-01-01")
    .join(other_df, "id")
    .select("col1", "col2")
)

# 6. Avoid UDFs when possible - use built-in functions

# 7. Check the execution plan
df.explain()  # physical plan only
df.explain(True)  # extended: parsed, analyzed and optimized logical plans plus the physical plan

Common Patterns

# Deduplication: one row per id. Which duplicate survives is not specified --
# use the window pattern below when the choice matters.
df_deduped = df.dropDuplicates(["id"])

# Get latest record per group. If updated_at values can tie, add a tiebreaker
# column to orderBy(); otherwise the surviving row among ties is arbitrary.
by_recency = Window.partitionBy("customer_id").orderBy(col("updated_at").desc())
latest = (
    df.withColumn("rn", row_number().over(by_recency))
    .filter(col("rn") == 1)
    .drop("rn")
)

# Flatten nested JSON: one output row per element of the `items` array
# (explode skips rows whose array is null or empty; explode_outer keeps them)
from pyspark.sql.functions import explode, col

df_exploded = df.select(
    col("id"),
    explode(col("items")).alias("item"),
).select(
    col("id"),
    col("item.name"),
    col("item.price"),
)

# Conditional aggregation: when() without otherwise() yields null when the
# condition is false, and count() skips nulls
df.groupBy("category").agg(
    count(when(col("status") == "completed", 1)).alias("completed_count"),
    count(when(col("status") == "pending", 1)).alias("pending_count"),
).show()

# Date operations
from pyspark.sql.functions import year, month, dayofmonth, datediff, date_add

df = df.withColumn("year", year(col("date")))
df = df.withColumn("month", month(col("date")))
df = df.withColumn("day", dayofmonth(col("date")))
df = df.withColumn("due_date", date_add(col("date"), 30))  # 30 days after `date`
df = df.withColumn("days_diff", datediff(col("end_date"), col("start_date")))  # end_date - start_date, in days

Master PySpark

Our Data Engineering program provides hands-on PySpark training with real-world big data projects.

Explore Data Engineering Program

Related Articles