aws_audit_logs_etl_uc(Python)

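# Notebook parameters: the target Unity Catalog catalog and schema, the S3 path where
# the audit logs are delivered, a checkpoint root, and the earliest delivery date to ingest.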
dbutils.widgets.text("catalog","")
dbutils.widgets.text("database","")
dbutils.widgets.text("log_bucket","s3://<bucket-name>/<delivery-path-prefix>")
dbutils.widgets.text("checkpoint","")
dbutils.widgets.text("start_date","yyyy-mm-dd")
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")
log_bucket = dbutils.widgets.get("log_bucket")
checkpoint = dbutils.widgets.get("checkpoint")
start_date = dbutils.widgets.get("start_date")
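# Select the target catalog and create the schema if it does not already exist.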
spark.sql(f"USE CATALOG {catalog}")
spark.sql(f"CREATE DATABASE IF NOT EXISTS {catalog}.{database}")
from pyspark.sql.functions import udf, col, from_unixtime, from_utc_timestamp, from_json, to_date, lit
from pyspark.sql.types import ArrayType, StringType, StructField, StructType
import json, time
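# Sample up to 10,000 files when inferring the audit log schema, since the JSON
# payload differs from service to service.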
spark.conf.set("spark.databricks.cloudFiles.schemaInference.sampleSize.numFiles", "10000")
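# Bronze: incrementally ingest the raw JSON audit logs from S3 with Auto Loader,
# rescuing columns that do not match the inferred schema and keeping only records
# delivered on or after start_date.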
streamDF = (
  spark
  .readStream
  .format("cloudFiles")
  .option("cloudFiles.format", "json")
  .option("cloudFiles.inferColumnTypes", True)
  .option("cloudFiles.schemaHints", "workspaceId long") 
  .option("cloudFiles.schemaLocation", f"{checkpoint}/audit_log_schema")
  .option("cloudFiles.schemaEvolutionMode", "rescue")
  .load(f"{log_bucket}")
  .where(col("date") >= to_date(lit(start_date), "yyyy-MM-dd"))
)
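# Write the raw stream to the bronze Delta table, partitioned by date and workspaceId.
# The availableNow trigger processes all pending files and then stops the query.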
streamQuery = (streamDF
 .writeStream
 .format("delta")
 .partitionBy("date", "workspaceId")
 .outputMode("append")
 .option("checkpointLocation", f"{checkpoint}/bronze")
 .option("mergeSchema", True)
 .trigger(availableNow=True)
 .table(f"{catalog}.{database}.bronze")
)
# Block until the bronze load has finished before optimizing and reading the table.
while spark.streams.active:
  print("Waiting for streaming query to finish.")
  time.sleep(5)
spark.sql(f"OPTIMIZE {database}.bronze")
def stripNulls(raw):
  # Keep only the non-null requestParams fields and serialize them to a JSON string.
  if raw is None:
    return "{}"
  return json.dumps({k: v for k, v in raw.asDict().items() if v is not None})
strip_udf = udf(stripNulls, StringType())
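# Silver: read bronze as a stream, flatten requestParams into a JSON string of its
# non-null fields, promote the caller's email, and convert the epoch-millisecond
# timestamp into a UTC timestamp.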
bronzeDF = spark.readStream.table(f"{catalog}.{database}.bronze")
 
query = (
  bronzeDF
  .withColumn("flattened", strip_udf("requestParams"))
  .withColumn("email", col("userIdentity.email"))
  .withColumn("date_time", from_utc_timestamp(from_unixtime(col("timestamp")/1000), "UTC"))
  .drop("requestParams")
  .drop("userIdentity")
)
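# Write the transformed stream to the silver Delta table, partitioned by date.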
(query
 .writeStream
 .format("delta")
 .partitionBy("date")
 .outputMode("append")
 .option("checkpointLocation", f"{checkpoint}/silver")
 .option("mergeSchema", True)
 .trigger(availableNow=True)
 .table(f"{catalog}.{database}.silver")
)
# Block until the silver load has finished.
while spark.streams.active:
  print("Waiting for streaming query to finish.")
  time.sleep(5)
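# Sanity check: silver should contain exactly one row per bronze record.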
assert(spark.table(f"{catalog}.{database}.bronze").count() == spark.table(f"{catalog}.{database}.silver").count())
spark.sql(f"OPTIMIZE {database}.silver")
def justKeys(string):
  # Return the top-level keys present in a flattened requestParams JSON string.
  return list(json.loads(string).keys())
just_keys_udf = udf(justKeys, ArrayType(StringType()))
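# Gold: for each serviceName, collect the set of requestParams keys observed in silver,
# build an all-string schema from those keys, and write a per-service table with the
# requestParams JSON parsed back into a struct.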
def flatten_table(service_name):
  flattenedStream = spark.readStream.table(f"{catalog}.{database}.silver")
  flattened = spark.table(f"{catalog}.{database}.silver")
  
  schema = StructType()
  
  keys = (
    flattened
    .filter(col("serviceName") == service_name)
    .select(just_keys_udf(col("flattened")).alias("keys"))
    .distinct()
    .collect()
  )

  keysDistinct = {key for row in keys for key in row["keys"] if key != ""}

  # Some services never populate requestParams; give those tables a placeholder
  # column so that from_json still has a usable schema.
  if len(keysDistinct) == 0:
    schema.add(StructField('placeholder', StringType()))
  else:
    for i in keysDistinct:
      schema.add(StructField(i, StringType()))
      
  # Hyphens are not valid in unquoted table names, so normalize the service name.
  table_name = service_name.replace("-", "_")
  
  (flattenedStream
   .filter(col("serviceName") == service_name)
   .withColumn("requestParams", from_json(col("flattened"), schema))
   .drop("flattened")
   .writeStream
   .partitionBy("date")
   .outputMode("append")
   .format("delta")
   .option("checkpointLocation", f"{checkpoint}/gold/{service_name}")
   .option("mergeSchema", True)
   .trigger(availableNow=True)
   .table(f"{catalog}.{database}.{table_name}")
   )
  spark.sql(f"OPTIMIZE {database}.{table_name}")
# Enumerate every serviceName present in silver and build one gold table per service.
service_name_list = [row["serviceName"] for row in spark.table(f"{catalog}.{database}.silver").select("serviceName").distinct().collect()]
for service_name in service_name_list:
  flatten_table(service_name)
while spark.streams.active:
  print("Waiting for streaming query to finish.")
  time.sleep(5)
# Compact the gold tables once their streaming loads have completed.
for service_name in service_name_list:
  spark.sql(f"OPTIMIZE {catalog}.{database}.{service_name.replace('-', '_')}")
display(spark.sql(f"SHOW TABLES IN {database}"))
display(spark.sql(f"SELECT count(*),date FROM {database}.unitycatalog group by date order by date"))
display(spark.sql(f"SELECT * FROM {database}.unitycatalog"))