Install and Compile Cython

This document explains how to run Spark code that uses compiled Cython modules. The steps are as follows:

  1. Create an example Cython module on the Databricks File System (DBFS).
  2. Add the file to the SparkSession.
  3. Create a wrapper method to load the module on the executors.
  4. Run the mapper on a sample dataset.
  5. Generate a larger dataset and compare the performance with the native Python example.

Note: By default, paths use dbfs:/ if no protocol is referenced.

# Write an example Cython module to /example/cython/fib.pyx in DBFS.
dbutils.fs.put("/example/cython/fib.pyx", """
def fib_mapper_cython(n):
    '''
    Return a (value, 1) pair, where value is the first Fibonacci number >= n.
    The count of 1 allows the results to be aggregated with reduceByKey.
    '''
    cdef int a = 0
    cdef int b = 1
    cdef int j = int(n)
    while b < j:
        a, b = b, a + b
    return b, 1
""", True)

# Write an example input file to /example/cython_input/input.txt in DBFS.
# Every line of this file is an integer.
dbutils.fs.put("/example/cython_input/input.txt", """
1
10
100
""", True)

# Take a look at the example input.
dbutils.fs.head("/example/cython_input/input.txt")
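Because paths default to dbfs:/ when no protocol is given (per the note above), the call below with an explicit protocol reads the same file:

# Equivalent call with the dbfs:/ protocol spelled out.
dbutils.fs.head("dbfs:/example/cython_input/input.txt")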

Add Cython Source Files to Spark

To make the Cython source files available across the cluster, we will use sc.addPyFile to add these files to Spark. For example,

sc.addPyFile("dbfs:/example/cython/fib.pyx")

If your Cython source files are stored in S3, you can use the s3a:// protocol to access those files. s3a:// also works well if you access the bucket through an IAM role. For example,

sc.addPyFile("s3a://dbc-mwc/cython/fib.pyx")

Test Cython compilation on the driver node

This code will test compilation on the driver node first.

import pyximport
import os

pyximport.install()
import fib
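If the import succeeds, the module has compiled on the driver. As a quick sanity check (not part of the original notebook), you can call the function directly:

# Should print (144, 1): 144 is the first Fibonacci number >= 100,
# and the 1 is the count used later by reduceByKey.
print(fib.fib_mapper_cython(100))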

Define the wrapper function to compile and import the module

The print statements will get executed on the executor nodes. You can view the stdout log messages to track the progress of your module.

import sys, os, shutil, cython

def spark_cython(module, method):
  def wrapped(*args, **kwargs):
    print('Entered function with: %s' % (args,))
    global cython_function_
    try:
      return cython_function_(*args, **kwargs)
    except:
      # First call on this executor: cython_function_ is not defined yet,
      # so compile the module with pyximport and cache the target function.
      import pyximport
      pyximport.install()
      print('Cython compilation complete')
      cython_function_ = getattr(__import__(module), method)
    print('Defined function: %s' % cython_function_)
    return cython_function_(*args, **kwargs)
  return wrapped

Run the Cython example

The snippet below runs the Fibonacci example on a few data points.

# Use the CSV reader to generate a Spark DataFrame, convert it to an RDD,
# and extract the single column value from each Row object.
lines = spark.read.csv("/example/cython_input/").rdd.map(lambda y: y[0])

mapper = spark_cython('fib', 'fib_mapper_cython')
fib_frequency = lines.map(mapper).reduceByKey(lambda a, b: a + b).collect()
print(fib_frequency)

Performance Comparison

Below we test the speed difference between the two implementations. We use the spark.range() API to generate data points from 10,000 to 100,000,000 across 50 Spark partitions and write the output to DBFS as CSV.

For this test, disable autoscaling so that the cluster keeps a fixed number of Spark executors.

dbutils.fs.rm("/tmp/cython_input/", True)
spark.range(10000, 100000000, 1, 50).write.csv("/tmp/cython_input/")
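As a quick sanity check (not part of the original notebook), you can confirm how many rows were generated before running the comparison:

# Expect 99,990,000 rows: the half-open range from 10,000 up to 100,000,000.
print(spark.read.csv("/tmp/cython_input/").count())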

Normal PySpark Code

def fib_mapper_python(n):
  a = 0
  b = 1
  print("Trying: %s" % n)
  while b < int(n):
    a, b = b, a + b
  return (b, 1)

print(fib_mapper_python(2000))

lines = spark.read.csv("/tmp/cython_input/").rdd.map(lambda y: y[0])
fib_frequency = lines.map(lambda x: fib_mapper_python(x)).reduceByKey(lambda a, b: a + b).collect()
print(fib_frequency)

Test Cython Code

Now we will test the compiled Cython code.

lines = spark.read.csv("/tmp/cython_input/").rdd.map(lambda y: y[0])
mapper = spark_cython('fib', 'fib_mapper_cython')
fib_frequency = lines.map(mapper).reduceByKey(lambda a, b: a + b).collect()
print(fib_frequency)
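Databricks reports the wall-clock time of each command, so the two runs above can be compared directly. If you want an explicit measurement, a minimal sketch using time.perf_counter() (the timed_run helper below is hypothetical, not part of the original notebook) could look like this:

import time

def timed_run(mapper_fn, label):
  # Hypothetical helper: time one full map/reduce pass over the generated dataset.
  start = time.perf_counter()
  rdd = spark.read.csv("/tmp/cython_input/").rdd.map(lambda y: y[0])
  result = rdd.map(mapper_fn).reduceByKey(lambda a, b: a + b).collect()
  print("%s finished in %.1f seconds" % (label, time.perf_counter() - start))
  return result

timed_run(lambda x: fib_mapper_python(x), "Pure Python")
timed_run(spark_cython('fib', 'fib_mapper_cython'), "Cython")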

The test dataset we generated has 50 Spark partitions, so Spark writes 50 CSV part files. You can view the dataset with dbutils.fs.ls("/tmp/cython_input/").
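For example, counting the part files (a small illustrative check, not in the original notebook):

# Count the CSV part files written by the 50 partitions; expect 50,
# alongside Spark bookkeeping files such as _SUCCESS.
files = dbutils.fs.ls("/tmp/cython_input/")
print(len([f for f in files if f.name.startswith("part-")]))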