Graph Analysis with GraphX Tutorial

In this tutorial we’ll go over basic graph analysis using the GraphX API, working with publicly available bike data from the Bay Area Bike Share portal. Specifically, we’ll analyze the second year of data.

Note

GraphX computation is only supported using the Scala and RDD APIs.


Graph Processing Primer

Graph processing is an important aspect of analysis that applies to many use cases. Fundamentally, graph theory and graph processing are about defining relationships between entities. Nodes (or vertices) are the units, and edges define the relationships between nodes. This works well for social network analysis and for running algorithms like PageRank to better understand and weigh relationships.

Some business use cases include finding central people in social networks [who is most popular in a group of friends], measuring the importance of papers in bibliographic networks [which papers are most referenced], and of course ranking web pages!

As mentioned, in this example we’ll be using Bay Area Bike Share data, which the portal mentioned above makes freely available to the public. We’ll orient our analysis by making every station a vertex and every trip an edge connecting two stations. Because each trip runs from a start station to an end station, this creates a directed graph.
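As a rough illustration of this orientation, here is a minimal plain-Scala sketch that does not use Spark. The station IDs and names are made up. Vertices map a numeric ID to a station name, and each trip is a directed (start, end) pair. Two riders can take the same route, so the same pair can appear more than once. Strictly speaking, that makes the graph a directed multigraph, and GraphX’s property graph allows such parallel edges.

```scala
// Hypothetical stations: numeric vertex ID -> station name (the vertex attribute)
val toyStations: Map[Long, String] = Map(
  1L -> "Station A",
  2L -> "Station B",
  3L -> "Station C"
)

// Each hypothetical trip is one directed edge: (start station ID, end station ID)
val toyTrips: Seq[(Long, Long)] = Seq(
  (1L, 2L),
  (1L, 2L), // a second trip on the same route: a parallel edge
  (2L, 3L),
  (3L, 1L)
)

val vertexCount = toyStations.size               // number of stations
val edgeCount = toyTrips.size                    // number of trips
val distinctRouteCount = toyTrips.distinct.size  // number of distinct (start, end) routes
```

Here there are more trips than distinct routes, which is exactly what the edge grouping later in the tutorial exploits.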

Further reference:
- Graph Theory on Wikipedia
- PageRank on Wikipedia

Setup & Data

To get the data into your workspace, download the data and unzip it on your computer. Then upload 201508_station_data.csv and 201508_trip_data.csv using the Tables UI in Databricks.

Don’t change any column types or do any other preprocessing before the upload; it isn’t necessary for this example. Once the files are uploaded as tables, we can load them as DataFrames below. I’ve named the tables sf_201508_station_data and sf_201508_trip_data, but feel free to call them something else!

val bikeStations = sqlContext.sql("SELECT * FROM sf_201508_station_data")
val tripData = sqlContext.sql("SELECT * FROM sf_201508_trip_data")
display(bikeStations)
display(tripData)

It’s often helpful to look at the exact schema to make sure each column has the type you expect. In this case we haven’t done any manipulation, so every column will be a string.

bikeStations.printSchema()
tripData.printSchema()

Imports

Before we continue, we need a few imports: everything we need from GraphX, plus Spark’s RDD type, which we’ll use to declare our vertices and edges.

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

Building the Graph

Now that we’ve imported our data, we need to build our graph. That takes two steps: building the structure of the vertices (or nodes) and building the structure of the edges.

You may have noticed that our bikeStations data contains station IDs but our trip data does not; it only records station names. This complicates things because GraphX identifies every vertex by a VertexId, which is a Long. Our vertices therefore need a numeric identifier rather than a string such as the station name. We have to perform some joins to associate those IDs with each trip.

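val justStations = bikeStations
  .selectExpr("cast(station_id as bigint) as station_id", "name") // GraphX vertex IDs are Longs
  .distinct()

// Renamed copies of the station lookup, so the two joins below don't
// create ambiguous "station_id" / "name" columns
val startStations = justStations
  .withColumnRenamed("station_id", "start_station_id")
  .withColumnRenamed("name", "start_name")
val endStations = justStations
  .withColumnRenamed("station_id", "end_station_id")
  .withColumnRenamed("name", "end_name")

val completeTripData = tripData
  .join(startStations, tripData("Start Station") === startStations("start_name"))
  .join(endStations, tripData("End Station") === endStations("end_name"))
  .drop("start_name", "end_name")

val stations = completeTripData
  .select("start_station_id", "end_station_id")
  .rdd
  .distinct() // keep each (start, end) station pair once
  .flatMap(x => Iterable(x(0).asInstanceOf[Number].longValue, x(1).asInstanceOf[Number].longValue)) // explicit Long conversion keeps types predictable
  .distinct() // each station ID once
  .toDF() // back to a DataFrame (single column named "value") for easier joining; needs spark.implicits._ in scope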

stations.take(1) // this is just a station_id at this point
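The joins above amount to looking up each trip’s station names in the station table. Below is a minimal plain-Scala sketch of that lookup, using hypothetical station names and IDs. It also shows a side effect of inner-join semantics. A trip whose start or end name has no match, for example a station that was renamed at some point, is silently dropped rather than given an ID.

```scala
// Hypothetical station table: (station_id, name)
val stationTable: Seq[(Long, String)] = Seq(
  (2L, "Station A"),
  (3L, "Station B")
)

// Hypothetical trips recorded by station *name*: (start name, end name)
val tripsByName: Seq[(String, String)] = Seq(
  ("Station A", "Station B"),
  ("Station B", "Station A"),
  ("Station A", "Renamed Station") // no matching name in stationTable
)

val idByName: Map[String, Long] = stationTable.map { case (id, name) => name -> id }.toMap

// Inner-join semantics: keep a trip only if BOTH names resolve to an ID
val tripsById: Seq[(Long, Long)] = tripsByName.flatMap { case (start, end) =>
  for {
    s <- idByName.get(start)
    e <- idByName.get(end)
  } yield (s, e)
}

// Same steps as the Spark code: distinct pairs, flatten to IDs, distinct IDs
val stationIds: Seq[Long] = tripsById.distinct.flatMap { case (s, e) => Seq(s, e) }.distinct
```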

Now we can create our set of vertices and attach a bit of metadata to each of them, which in this case is the name of the station.

val stationVertices: RDD[(VertexId, String)] = stations
  .join(justStations, stations("value") === justStations("station_id"))
  .select("station_id", "name")
  .rdd
  .map(row => (row(0).asInstanceOf[Number].longValue, row(1).asInstanceOf[String])) // maintain type information

stationVertices.take(1)

Now we can create the trip edges from all of our individual rides. For each ride we take the start and end station IDs and give the edge a placeholder attribute of 1, so that summing edge attributes later counts trips.

val stationEdges: RDD[Edge[Long]] = completeTripData
  .select("start_station_id", "end_station_id")
  .rdd
  .map(row => Edge(row(0).asInstanceOf[Number].longValue, row(1).asInstanceOf[Number].longValue, 1L)) // attribute 1L = one trip

Now we can build our graph. You’ll notice below that I define a default station. GraphX assigns this attribute to any vertex that an edge points to but that is missing from our vertex set. That could happen through a collection error, or through a station that has gone out of service. It’s worth understanding the data collection process, both current and historical, to decide whether this merits more thought when you apply it to your own data.

I’m also going to cache our graph for faster access.

val defaultStation = "Missing Station"
val stationGraph = Graph(stationVertices, stationEdges, defaultStation)
stationGraph.cache()
println("Total Number of Stations: " + stationGraph.numVertices)
println("Total Number of Trips: " + stationGraph.numEdges)
// sanity check: this can differ from numEdges, e.g. trips whose station names
// found no match in the station table were dropped by the inner joins
println("Total Number of Trips in Original Data: " + tripData.count)
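GraphX’s Graph constructor takes the default vertex attribute as its third argument. It uses that attribute for vertex IDs that are mentioned in the edges but absent from the vertex RDD. Below is a plain-Scala sketch of that filling step, with hypothetical IDs and names; vertex 9 stands in for a missing station. It mimics the behavior only and does not reproduce GraphX’s implementation.

```scala
// Plays the role of defaultStation in the tutorial
val missingStationLabel = "Missing Station"

// Hypothetical vertices and edges (src, dst, attr); vertex 9 appears in an edge but has no attribute
val toyVertices: Map[Long, String] = Map(1L -> "Station A", 2L -> "Station B")
val toyEdges: Seq[(Long, Long, Long)] = Seq((1L, 2L, 1L), (2L, 9L, 1L))

// Every ID referenced by either end of an edge needs a vertex
val referencedIds: Set[Long] = toyEdges.flatMap { case (src, dst, _) => Seq(src, dst) }.toSet

// Referenced IDs without an attribute get the default label
val completedVertices: Map[Long, String] =
  toyVertices ++ (referencedIds -- toyVertices.keySet).map(id => id -> missingStationLabel)
```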

Now that we’re all set up and have computed some basic statistics, let’s run some algorithms!

PageRank

GraphX includes a number of built-in algorithms. PageRank, created by Larry Page and popularized by the Google search engine, is one of the most popular. To quote Wikipedia:

PageRank works by counting the number and quality of links to a page to determine a rough estimate of how important the website is. The underlying assumption is that more important websites are likely to receive more links from other websites.

What’s great about this concept is that it readily applies to any graph-like structure, be it web pages or bike stations. Let’s go ahead and run PageRank on our data. GraphX can run it either for a fixed number of iterations or until convergence. pageRank takes a tolerance (a Double) and iterates until the ranks stop changing by more than that amount. staticPageRank takes an Int and runs exactly that many iterations.

val ranks = stationGraph.pageRank(0.0001).vertices
ranks
  .join(stationVertices)
  .sortBy(_._2._1, ascending=false) // sort by the rank
  .take(10) // get the top 10
  .foreach(x => println(x._2._2))
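Below is a minimal plain-Scala sketch of the two stopping rules, using the textbook formulation PR(v) = (1 - d)/N + d · Σ PR(u)/outDeg(u) over v’s in-neighbors u. Here d is the damping factor; GraphX’s resetProb parameter corresponds to 1 - d. This is an assumption-laden illustration, not GraphX’s code. It does not reproduce GraphX’s rank scaling or its exact handling of vertices with no out-edges, so only the overall behavior carries over, not the absolute numbers. The graph in the usage lines is hypothetical.

```scala
// One PageRank update over an edge list; d is the damping factor
def pageRankStep(nodes: Seq[Long], edges: Seq[(Long, Long)],
                 ranks: Map[Long, Double], d: Double): Map[Long, Double] = {
  val n = nodes.size.toDouble
  val outDeg: Map[Long, Int] = edges.groupBy(_._1).map { case (src, es) => src -> es.size }
  // Rank held by vertices with no out-edges is spread evenly over all vertices
  val danglingMass = nodes.filterNot(outDeg.contains).map(ranks).sum
  // Each source splits its rank evenly across its out-edges (parallel edges count separately)
  val incoming: Map[Long, Double] = edges.groupBy(_._2).map { case (dst, es) =>
    dst -> es.map { case (src, _) => ranks(src) / outDeg(src) }.sum
  }
  nodes.map { v =>
    v -> ((1 - d) / n + d * (incoming.getOrElse(v, 0.0) + danglingMass / n))
  }.toMap
}

// Fixed number of iterations (the staticPageRank style)
def pageRankFixed(nodes: Seq[Long], edges: Seq[(Long, Long)],
                  numIter: Int, d: Double = 0.85): Map[Long, Double] = {
  val init = nodes.map(v => v -> 1.0 / nodes.size).toMap
  (1 to numIter).foldLeft(init)((r, _) => pageRankStep(nodes, edges, r, d))
}

// Iterate until no rank changes by more than tol (the pageRank(tol) style)
def pageRankUntilConverged(nodes: Seq[Long], edges: Seq[(Long, Long)],
                           tol: Double, d: Double = 0.85): Map[Long, Double] = {
  var ranks = nodes.map(v => v -> 1.0 / nodes.size).toMap
  var delta = Double.MaxValue
  while (delta > tol) {
    val next = pageRankStep(nodes, edges, ranks, d)
    delta = nodes.map(v => math.abs(next(v) - ranks(v))).max
    ranks = next
  }
  ranks
}

// Usage on a hypothetical 3-station graph
val sampleNodes = Seq(1L, 2L, 3L)
val sampleEdges = Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L))
val fixedRanks = pageRankFixed(sampleNodes, sampleEdges, numIter = 200)
val convergedRanks = pageRankUntilConverged(sampleNodes, sampleEdges, tol = 1e-10)
```

With enough iterations, both stopping rules land on essentially the same ranks. The tolerance version simply decides for itself when to stop.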

We can see above that the Caltrain stations seem to be significant! This makes sense. They are natural connectors, and one of the most popular uses of these bike share programs is likely getting from A to B without a car.

Trips From Station to Station

One question is which station-to-station routes are the most common in the dataset. We can answer this by grouping edges and adding their attributes together. This yields a new graph in which each edge carries the sum of all edges with the same source and destination. Think about it this way: we have a number of identical trips from station A to station B, and we just want to count them up! Because groupEdges only merges parallel edges that sit in the same partition, we first repartition the graph with partitionBy.

In the query below, we find the most common station-to-station trips and print the top 10.

stationGraph
  .partitionBy(PartitionStrategy.RandomVertexCut) // co-locate edges that share the same (src, dst) pair
  .groupEdges((edge1, edge2) => edge1 + edge2) // sum the per-trip attributes of parallel edges
  .triplets
  .sortBy(_.attr, ascending=false)
  .map(triplet =>
    "There were " + triplet.attr.toString + " trips from " + triplet.srcAttr + " to " + triplet.dstAttr + ".")
  .take(10)
  .foreach(println)
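Conceptually, grouping edges is a group-by on the (source, destination) pair followed by a sum of the edge attributes. The plain-Scala sketch below shows that computation on hypothetical trips. Because the graph is directed, A to B and B to A stay separate routes.

```scala
// Hypothetical trips as (src, dst, attr) edges, each carrying attribute 1 (one trip)
val tripEdges: Seq[(Long, Long, Long)] = Seq(
  (1L, 2L, 1L), (1L, 2L, 1L), (1L, 2L, 1L),
  (2L, 1L, 1L),
  (2L, 3L, 1L), (2L, 3L, 1L)
)

// Merge parallel edges (same src AND same dst) by summing their attributes
val routeCounts: Map[(Long, Long), Long] = tripEdges
  .groupBy { case (src, dst, _) => (src, dst) }
  .map { case (route, es) => route -> es.map(_._3).sum }

// Most common routes first, keeping the top 2
val topRoutes: Seq[((Long, Long), Long)] =
  routeCounts.toSeq.sortBy { case (_, count) => -count }.take(2)
```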

In Degrees and Out Degrees

Remember that in this instance we’ve got a directed graph. That means that our trips are directional, each going from one location to another. We can therefore count, for each station, the number of trips that arrive there (its in-degree) and the number that leave from it (its out-degree).

Naturally we can sort this information and find the stations with lots of inbound and outbound trips! Check out the definition of vertex degree for more information.

With that defined, let’s find the stations that have the most inbound and outbound traffic.

stationGraph
  .inDegrees // computes in Degrees
  .join(stationVertices)
  .sortBy(_._2._1, ascending=false)
  .take(10)
  .foreach(x => println(x._2._2 + " has " + x._2._1 + " in degrees."))
stationGraph
  .outDegrees // out degrees
  .join(stationVertices)
  .sortBy(_._2._1, ascending=false)
  .take(10)
  .foreach(x => println(x._2._2 + " has " + x._2._1 + " out degrees."))
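In plain Scala, in-degree and out-degree are just counts of edges grouped by destination and by source. The sketch below uses a hypothetical edge list. As with GraphX’s inDegrees and outDegrees, a vertex with no incoming (or outgoing) edges simply does not appear in the result, and each parallel edge counts once.

```scala
// Hypothetical trips as (start station ID, end station ID)
val sampleTrips: Seq[(Long, Long)] = Seq((1L, 2L), (1L, 2L), (2L, 3L), (3L, 2L), (1L, 3L))

// In-degree: number of edges ending at each vertex
val inDegreeCounts: Map[Long, Int] = sampleTrips.groupBy(_._2).map { case (v, es) => v -> es.size }

// Out-degree: number of edges starting at each vertex
val outDegreeCounts: Map[Long, Int] = sampleTrips.groupBy(_._1).map { case (v, es) => v -> es.size }
```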

One interesting follow-up question is which station has the highest ratio of in-degree to out-degree. In other words, which station acts most like a trip sink, where trips end but rarely start?

stationGraph
  .inDegrees
  .join(stationGraph.outDegrees) // join with out Degrees
  .join(stationVertices) // join with our other stations
  .map(x => (x._2._1._1.toDouble/x._2._1._2.toDouble, x._2._2)) // ratio of in to out
  .sortBy(_._1, ascending=false)
  .take(5)
  .foreach(x => println(x._2 + " has an in/out degree ratio of " + x._1))
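One caveat with the query above: join is an inner join, and outDegrees omits stations with no outbound trips. A pure sink, where trips end but never start, therefore never appears in the results. The plain-Scala sketch below uses hypothetical degree counts to contrast two views. The first is the inner-join view. The second treats a missing out-degree as 0, which produces an infinite ratio. In Spark, leftOuterJoin could play the role of the second view.

```scala
// Hypothetical degree maps; as with GraphX, vertices with degree 0 are absent
val toyIn: Map[Long, Int] = Map(1L -> 10, 2L -> 4, 3L -> 6)
val toyOut: Map[Long, Int] = Map(1L -> 2, 2L -> 8) // vertex 3 has no outbound trips

// Inner-join view, like inDegrees.join(outDegrees): vertex 3 is dropped
val innerRatios: Map[Long, Double] = toyIn
  .filter { case (v, _) => toyOut.contains(v) }
  .map { case (v, inCount) => v -> inCount.toDouble / toyOut(v) }

// Outer view: a missing out-degree counts as 0, so a pure sink gets an infinite ratio
val allRatios: Map[Long, Double] = toyIn.map { case (v, inCount) =>
  v -> inCount.toDouble / toyOut.getOrElse(v, 0)
}
```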

We can do something similar by finding the stations with the lowest in-degree to out-degree ratios, meaning that trips start from those stations but don’t end there as often. This is essentially the opposite of the query above.

stationGraph
  .inDegrees
  .join(stationGraph.outDegrees) // join with out degrees
  .join(stationVertices) // join with our other stations
  .map(x => (x._2._1._1.toDouble/x._2._1._2.toDouble, x._2._2)) // ratio of in to out
  .sortBy(_._1)
  .take(5)
  .foreach(x => println(x._2 + " has an in/out degree ratio of " + x._1))