vpc-peering(Scala)

Step 1: Enter your AWS Credentials

/* Always required: the three values below are needed for every VPC peering setup. */

// Access key ID and secret key of an IAM user in the AWS account where Databricks is deployed.
// The user must be allowed to create VPC peering connections; the setup also creates routes
// and security group rules with these credentials. Avoid saving real keys in a shared notebook.
val databricksVpcAWSAccessKey: String = ""
val databricksVpcAWSAccessSecretKey: String = ""

// ID of the VPC that Databricks should be peered with.
val externalVpcID: String = ""
databricksVpcAWSAccessKey: String = "" databricksVpcAWSAccessSecretKey: String = "" externalVpcID: String = ""
/* Cross-account only: fill in the three values below if and only if the external VPC belongs to a
   DIFFERENT AWS account than Databricks. For a VPC in the same account, keep them as empty strings. */
// Credentials of an IAM user in the external VPC's AWS account; used to accept the peering
// request and to add routes inside the external VPC.
val externalVpcAWSAccessKey: String = ""
val externalVpcAWSAccessSecretKey: String = ""
// AWS account ID that owns the external VPC (digits only).
val dstVpcOwnerAccountId: String = ""
externalVpcAWSAccessKey: String = "" externalVpcAWSAccessSecretKey: String = "" dstVpcOwnerAccountId: String = ""

Step 2: Run the cell below, which defines the helper object and class used to set up VPC peering.

import com.amazonaws.AmazonServiceException
import com.amazonaws.auth.AWSCredentials
import com.amazonaws.regions.Regions
import com.amazonaws.services.ec2.AmazonEC2Client
import com.amazonaws.services.ec2.model._
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.experimental.ScalaObjectMapper

import scala.collection.JavaConversions._

/**
 * Helpers that read facts about the current (Databricks) EC2 instance from the
 * EC2 instance metadata service at 169.254.169.254.
 */
object VpcPeeringUtils {

  val mapper: ObjectMapper with ScalaObjectMapper = new ObjectMapper() with ScalaObjectMapper

  private val MetadataBaseUrl = "http://169.254.169.254/latest"

  /**
   * Get the VPC ID of the Databricks VPC.
   *
   * Uses the MAC of the instance's primary network interface (`meta-data/mac`). The
   * `network/interfaces/macs` listing is not used because it returns one `mac/` entry per line,
   * which cannot be spliced into a URL when several interfaces are attached.
   *
   * @return VPC ID of the VPC this instance runs in
   */
  def getLocalVpc: String = {
    val mac = getUrl(s"$MetadataBaseUrl/meta-data/mac").trim
    getUrl(s"$MetadataBaseUrl/meta-data/network/interfaces/macs/$mac/vpc-id").trim
  }

  /**
   * Get the region the Databricks platform is deployed in.
   * @return the `region` field of the instance identity document
   */
  def getRegion: String = identityDocumentField("region")

  /**
   * Get the AWS account ID the Databricks platform is deployed in.
   * @return the `accountId` field of the instance identity document
   */
  def getLocalVpcAccountId: String = identityDocumentField("accountId")

  /**
   * Get the name of the security group assigned to Spark EC2 nodes.
   *
   * If exactly one group is attached (a legacy single-security-group setup), that group is
   * returned. If several are attached, the first whose name contains "unmanaged" is returned.
   *
   * @return Security Group Name
   * @throws IllegalStateException if no group is attached, or several are attached and none
   *                               of them contains "unmanaged"
   */
  def getLocalSecurityGroups: String = {
    val groups = getUrl(s"$MetadataBaseUrl/meta-data/security-groups")
      .split("\n")
      .map(_.trim)
      .filter(_.nonEmpty)

    groups match {
      case Array(only) => only
      case Array() =>
        throw new IllegalStateException("No security groups are attached to this instance")
      case _ =>
        // Concatenating several names (or returning "") would yield a name matching no group,
        // which silently skips the security group rule later on.
        groups.find(_.contains("unmanaged")).getOrElse(
          throw new IllegalStateException(
            s"No 'unmanaged' security group among attached groups: ${groups.mkString(", ")}"))
    }
  }

  /**
   * Read one top-level field of the instance identity document as text.
   * Parsed as a JSON tree so that non-string fields in the document do not break parsing.
   */
  private def identityDocumentField(field: String): String = {
    val json = getUrl(s"$MetadataBaseUrl/dynamic/instance-identity/document")
    Option(mapper.readTree(json).get(field))
      .map(_.asText)
      .getOrElse(throw new IllegalStateException(s"Field '$field' missing from instance identity document"))
  }

  /**
   * Helper function to call the local EC2 metadata service.
   * @param url metadata URL to fetch
   * @return response body decoded as UTF-8
   * @throws IllegalStateException if the service answers with a non-200 status
   */
  private def getUrl(url: String): String = {
    import org.apache.commons.httpclient.methods.GetMethod
    import org.apache.commons.httpclient._

    val client = new HttpClient()
    val get = new GetMethod(url)

    try {
      val status = client.executeMethod(get)
      // Without this check an error page (e.g. a 404 body) would be returned as if it were the value.
      if (status != HttpStatus.SC_OK) {
        throw new IllegalStateException(s"Instance metadata request to $url failed with HTTP status $status")
      }
      scala.io.Source.fromInputStream(get.getResponseBodyAsStream)(scala.io.Codec.UTF8).mkString
    } finally {
      get.releaseConnection()
    }
  }

}

/**
 * The main class responsible for handling VPC peering.
 *
 * Peers the Databricks VPC (the VPC of the instance this code runs on) with an external VPC,
 * either in the same AWS account or in a different one. It then adds routes in both VPCs and an
 * ingress rule on the Databricks security group.
 *
 * If both VPCs belong to the same AWS account, externalVpcCreds and dstVpcOwnerAccountId are
 * optional: one set of AWS credentials is enough.
 *
 * If the external VPC is in a different AWS account, externalVpcCreds and dstVpcOwnerAccountId
 * are required.
 *
 * Both EC2 clients use the region of the current instance, so the external VPC is looked up in
 * that same region.
 *
 * @param databricksVpcCreds AWS credentials in the AWS account where the Databricks platform is deployed.
 * @param externalVpcID The ID of the external VPC the peering connection is established with.
 * @param externalVpcCreds AWS credentials for the external VPC's account. Only required if the
 *                         external VPC is in a different AWS account than the Databricks VPC.
 * @param dstVpcOwnerAccountId AWS account ID that owns the external VPC. Only required if the
 *                             external VPC is in a different AWS account; defaults to the
 *                             Databricks account.
 * @param subnets Optional route table IDs in the external VPC that receive the return route.
 *                Despite the parameter name, the values are used as route table IDs. When
 *                absent, every route table in the external VPC is updated.
 */
class VpcPeering (databricksVpcCreds:AWSCredentials, externalVpcID:String, externalVpcCreds:Option[AWSCredentials]=None, dstVpcOwnerAccountId:Option[String]=None,subnets:Option[Seq[String]]=None) {

  import scala.collection.JavaConverters._

  // Region of the instance this runs on; both clients are bound to it.
  private val region: Regions = Regions.fromName(VpcPeeringUtils.getRegion)

  // Databricks VPC ID, read from instance metadata once, on first use.
  private lazy val localVpcId: String = VpcPeeringUtils.getLocalVpc

  val srcVpcClient: AmazonEC2Client =
    new AmazonEC2Client(databricksVpcCreds).withRegion[AmazonEC2Client](region)
  val dstVpcClient: AmazonEC2Client =
    new AmazonEC2Client(externalVpcCreds.getOrElse(databricksVpcCreds)).withRegion[AmazonEC2Client](region)

  /**
   * Requests a peering connection from the Databricks VPC to the external VPC.
   * @return the ID of the new VPC peering connection
   */
  def createVpcPeeringRequest: String = {
    val request = new CreateVpcPeeringConnectionRequest()
      .withVpcId(localVpcId)
      .withPeerVpcId(externalVpcID)
      .withPeerOwnerId(dstVpcOwnerAccountId.getOrElse(VpcPeeringUtils.getLocalVpcAccountId))
    srcVpcClient.createVpcPeeringConnection(request).getVpcPeeringConnection.getVpcPeeringConnectionId
  }

  /**
   * Accepts a peering connection using the external VPC's client.
   * @param peeringId ID returned by [[createVpcPeeringRequest]]
   * @return the ID of the accepted peering connection
   */
  def acceptVpcPeeringRequest(peeringId: String): String = {
    val request = new AcceptVpcPeeringConnectionRequest().withVpcPeeringConnectionId(peeringId)
    dstVpcClient.acceptVpcPeeringConnection(request).getVpcPeeringConnection.getVpcPeeringConnectionId
  }

  /**
   * Looks up the IPv4 CIDR block (`Vpc.getCidrBlock`) of a VPC.
   * @throws IllegalArgumentException if the describe call returns no VPC
   */
  def getVpcCidr(client: AmazonEC2Client, vpcId: String): String = {
    val vpcs = client.describeVpcs(new DescribeVpcsRequest().withVpcIds(vpcId)).getVpcs.asScala
    vpcs.headOption
      .map(_.getCidrBlock)
      .getOrElse(throw new IllegalArgumentException(s"VPC $vpcId was not found"))
  }

  /**
   * Routes traffic between the two VPCs over the peering connection and opens the Databricks
   * security group to the external VPC's CIDR.
   *
   * Each route and each security group rule is created independently. One that already exists
   * is reported and skipped without preventing the others from being created. Any other AWS
   * error propagates.
   *
   * @param peeringId ID of the VPC peering connection to route through
   */
  def setupRouting(peeringId: String): Unit = {
    val srcRouteTableIds = routeTableIds(srcVpcClient, localVpcId)
    val dstRouteTableIds = subnets.getOrElse(routeTableIds(dstVpcClient, externalVpcID))

    val srcVpcCidr = getVpcCidr(srcVpcClient, localVpcId)
    val dstVpcCidr = getVpcCidr(dstVpcClient, externalVpcID)

    // Databricks VPC -> external VPC
    srcRouteTableIds.foreach(id => addPeeringRoute(srcVpcClient, id, dstVpcCidr, peeringId))
    // External VPC -> Databricks VPC
    dstRouteTableIds.foreach(id => addPeeringRoute(dstVpcClient, id, srcVpcCidr, peeringId))

    addRuleToSecurityGroup(srcVpcClient, dstVpcCidr, VpcPeeringUtils.getLocalSecurityGroups)

    printFinalInstruction(srcVpcCidr)
  }

  /** IDs of all route tables in the given VPC. */
  private def routeTableIds(client: AmazonEC2Client, vpcId: String): Seq[String] = {
    val request = new DescribeRouteTablesRequest()
      .withFilters(new Filter().withName("vpc-id").withValues(vpcId))
    client.describeRouteTables(request).getRouteTables.asScala.map(_.getRouteTableId).toList
  }

  /** Adds a route for `destinationCidr` through the peering connection to one route table. */
  private def addPeeringRoute(client: AmazonEC2Client, routeTableId: String, destinationCidr: String, peeringId: String): Unit =
    ignoringAwsError("RouteAlreadyExists", "Route table had the necessary route rule") {
      client.createRoute(new CreateRouteRequest()
        .withRouteTableId(routeTableId)
        .withDestinationCidrBlock(destinationCidr)
        .withVpcPeeringConnectionId(peeringId))
    }

  /** Allows inbound TCP on all ports from `cidr` on every security group named `groupName`. */
  private def addRuleToSecurityGroup(client: AmazonEC2Client, cidr: String, groupName: String): Unit = {
    val request = new DescribeSecurityGroupsRequest()
      .withFilters(new Filter().withName("group-name").withValues(groupName))
    val groupIds = client.describeSecurityGroups(request).getSecurityGroups.asScala.map(_.getGroupId)

    groupIds.foreach { groupId =>
      ignoringAwsError("InvalidPermission.Duplicate", "Security Group Rule Already Exists in the Databricks VPC Security Group") {
        // 65535 is the highest TCP port, so 0-65535 covers the whole port range.
        client.authorizeSecurityGroupIngress(new AuthorizeSecurityGroupIngressRequest()
          .withCidrIp(cidr)
          .withGroupId(groupId)
          .withFromPort(0)
          .withToPort(65535)
          .withIpProtocol("tcp"))
      }
    }
  }

  /**
   * Runs `action`. If it fails with an AWS error whose code is `errorCode`, prints `message`
   * and continues; any other exception propagates to the caller.
   */
  private def ignoringAwsError(errorCode: String, message: String)(action: => Unit): Unit =
    try action
    catch {
      case e: AmazonServiceException if e.getErrorCode == errorCode => println(message)
    }

  private def printFinalInstruction(localVpcCidr: String): Unit = {
    println(s"In order for Spark cluster nodes to access instances in your VPC, please add the following IP CIDR to the required security groups in $externalVpcID: $localVpcCidr")
  }

}
import com.amazonaws.AmazonServiceException import com.amazonaws.auth.AWSCredentials import com.amazonaws.regions.Regions import com.amazonaws.services.ec2.AmazonEC2Client import com.amazonaws.services.ec2.model._ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.experimental.ScalaObjectMapper import scala.collection.JavaConversions._ defined module VpcPeeringUtils defined class VpcPeering

Step 3: Run the code below to create the VPC peering connection.

import com.amazonaws.auth.BasicAWSCredentials

// Validate the Step 1 parameters before any AWS call is made.
require(databricksVpcAWSAccessKey.nonEmpty && databricksVpcAWSAccessSecretKey.nonEmpty,
  "databricksVpcAWSAccessKey and databricksVpcAWSAccessSecretKey must be set")
require(externalVpcID.nonEmpty, "externalVpcID must be set")

// The cross-account parameters are all-or-nothing. A partially filled set would otherwise be
// ignored, and a same-account peering request would be sent to a VPC in another account.
val crossAccountParams = Seq(dstVpcOwnerAccountId, externalVpcAWSAccessKey, externalVpcAWSAccessSecretKey)
val isCrossAccount = crossAccountParams.forall(_.nonEmpty)
require(isCrossAccount || crossAccountParams.forall(_.isEmpty),
  "Set all of dstVpcOwnerAccountId, externalVpcAWSAccessKey and externalVpcAWSAccessSecretKey, or leave all of them empty")
require(!isCrossAccount || dstVpcOwnerAccountId.forall(_.isDigit),
  "dstVpcOwnerAccountId must contain only digits")

val databricksCreds = new BasicAWSCredentials(databricksVpcAWSAccessKey, databricksVpcAWSAccessSecretKey)

val vpcPeering = if (isCrossAccount) {
  new VpcPeering(
    databricksVpcCreds = databricksCreds,
    externalVpcID = externalVpcID,
    externalVpcCreds = Some(new BasicAWSCredentials(externalVpcAWSAccessKey, externalVpcAWSAccessSecretKey)),
    dstVpcOwnerAccountId = Some(dstVpcOwnerAccountId)
  )
} else {
  new VpcPeering(databricksVpcCreds = databricksCreds, externalVpcID = externalVpcID)
}

val peeringId = vpcPeering.createVpcPeeringRequest
// It may take a minute or two for the following to succeed.
vpcPeering.acceptVpcPeeringRequest(peeringId)
vpcPeering.setupRouting(peeringId)

Step 4: Modify your security groups to allow the instances to access each other.

You can allow the entire VPC IP address range (for example, "10.137.0.0/17") if you want all instances in the two VPCs to be able to access one another.