Some Data Gymnastics with Spark 2.0's Datasets and Equivalent Code in RDDs

While working on a healthcare use case involving Patient 360-degree analytics, I came across a situation I wanted to solve. I am giving a very simplified example below.

A Patient has a Claims dataset which contains:

PatientID, Diagnosis1, Diagnosis2, Other Fields…

and a given Patient can have multiple rows of Claims data.

A Patient has a Labs dataset which contains:

PatientID, Lab1, Lab2, Other Fields…

and a given Patient can have multiple rows of Lab data.

A Patient has an RxClaims dataset which contains:

PatientID, Drug1, Drug2, Other Fields…

and a given Patient can have multiple rows of RxClaims data.

My Claims data would typically look like this (again, a simplified example):

"PID1", "diag1", "diag2"
"PID1", "diag2", "diag3"
"PID2", "diag2", "diag3"
"PID2", "diag1", "diag3"
"PID3", "diag1", "diag2"
"PID2", "diag2", "diag4"

My Lab data would look like this (again, a simplified example):

"PID1", "lab1", "lab2"
"PID1", "lab2", "lab3"
"PID2", "lab2", "lab3"
"PID2", "lab1", "lab3"
"PID3", "lab1", "lab2"

My RxClaims data would look like this (again, a simplified example):

"PID1", "drug1", "drug2"
"PID1", "drug2", "drug3"
"PID2", "drug2", "drug3"
"PID2", "drug1", "drug3"
"PID3", "drug1", "drug2"

I have shown the example with only three datasets (Claims, Lab, RxClaims); in practice I will have more.

This is what I wanted to get from the above data, which in the RDD world would have the following data structure:

org.apache.spark.rdd.RDD[(String, (Iterable[Claim], Iterable[Lab], Iterable[RxClaims]))]

(Image: the desired output, each patient id paired with its grouped Claims, Lab, and RxClaims rows.)
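To make the target concrete, here is what the record for PID1 would contain, written out by hand from the sample rows above (a sketch using the case classes defined in the code further down):

// One element of the target RDD, spelled out for PID1:
("PID1", (
  Seq(Claim("PID1", "diag1", "diag2"), Claim("PID1", "diag2", "diag3")),      // all of PID1's claim rows
  Seq(Lab("PID1", "lab1", "lab2"), Lab("PID1", "lab2", "lab3")),              // all of PID1's lab rows
  Seq(RxClaims("PID1", "drug1", "drug2"), RxClaims("PID1", "drug2", "drug3")) // all of PID1's rx rows
))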

Spark 2.0 with Datasets

(tested with the preview release of Spark 2.0)
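The snippets below were run in the spark-shell, where a SparkSession named spark and its implicits are already in scope. Outside the shell you would need roughly the following setup first (a sketch; the app name and master are placeholders):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("Patient360")   // hypothetical app name
  .master("local[*]")      // local mode, for experimenting
  .getOrCreate()

import spark.implicits._   // brings in toDS/toDF and the case-class encoders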

case class Claim(pid:String, diag1:String, diag2:String)
case class Lab(pid:String, lab1:String, lab2:String)
case class RxClaim(pid:String, drug1:String, drug2:String)

case class PatientClaims(pid:String, claims:Array[Claim])
case class PatientLab(pid:String, labs:Array[Lab])
case class PatientRxClaim(pid:String, rxclaims:Array[RxClaim])

case class Patient(pid:String, claims:Array[Claim], labs:Array[Lab], rxclaims:Array[RxClaim])

// Same sample rows as the tables above
val claimsData = Seq(
  Claim("PID1", "diag1", "diag2"), Claim("PID1", "diag2", "diag3"),
  Claim("PID2", "diag2", "diag3"), Claim("PID2", "diag1", "diag3"),
  Claim("PID3", "diag1", "diag2"), Claim("PID2", "diag2", "diag4"))

val labsData = Seq(
  Lab("PID1", "lab1", "lab2"), Lab("PID1", "lab2", "lab3"),
  Lab("PID2", "lab2", "lab3"), Lab("PID2", "lab1", "lab3"),
  Lab("PID3", "lab1", "lab2"))

val rxClaimsData = Seq(
  RxClaim("PID1", "drug1", "drug2"), RxClaim("PID1", "drug2", "drug3"),
  RxClaim("PID2", "drug2", "drug3"), RxClaim("PID2", "drug1", "drug3"),
  RxClaim("PID3", "drug1", "drug2"))

val claimRDD = spark.sparkContext.parallelize(claimsData)
val labRDD = spark.sparkContext.parallelize(labsData)
val rxclaimRDD = spark.sparkContext.parallelize(rxClaimsData)

// Turn the RDD into a typed Dataset, group by patient id, and
// collapse each patient's claims into a single PatientClaims row
val claimDS = claimRDD.toDS()
val claimsDSGroupedByPID = claimDS.groupByKey(_.pid)
val gClaims = claimsDSGroupedByPID.mapGroups((pid, claims) => PatientClaims(pid, claims.toArray))


// Same pattern for labs
val labDS = labRDD.toDS()
val labDSGroupedByPID = labDS.groupByKey(_.pid)
val gLabs = labDSGroupedByPID.mapGroups((pid, labs) => PatientLab(pid, labs.toArray))

// ...and for pharmacy claims
val rxClaimDS = rxclaimRDD.toDS()
val rxClaimsDSGroupedByPID = rxClaimDS.groupByKey(_.pid)
val gRxClaim = rxClaimsDSGroupedByPID.mapGroups((pid, rx) => PatientRxClaim(pid, rx.toArray))

// Joining on "pid" drops down to an untyped DataFrame with columns
// pid, claims, labs, rxclaims, which maps straight back onto Patient
val allJoined = gClaims.join(gLabs, "pid").join(gRxClaim, "pid")

val allJoinedDS = allJoined.as[Patient]

allJoinedDS.show(false)
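As an aside, if you want to stay fully typed and skip the detour through an untyped DataFrame, joinWith should also work here, at the cost of tuple-shaped intermediate results (a sketch I have not run against the preview build):

// The two joinWith calls yield Dataset[((PatientClaims, PatientLab), PatientRxClaim)],
// which we then flatten into Patient
val typedJoin = gClaims
  .joinWith(gLabs, gClaims("pid") === gLabs("pid"))
  .joinWith(gRxClaim, $"_1.pid" === gRxClaim("pid"))
  .map { case ((pc, pl), rx) => Patient(pc.pid, pc.claims, pl.labs, rx.rxclaims) }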

Spark 2.0 code with RDDs

(no Datasets / DataFrames)

case class Claim(pid:String, diag1:String, diag2:String)
case class Lab(pid:String, lab1:String, lab2:String)
case class RxClaims(pid:String, drug1:String, drug2:String)

// Sample data, keyed by patient id so it can be cogrouped
val claim = Seq(
  Claim("PID1", "diag1", "diag2"), Claim("PID1", "diag2", "diag3"),
  Claim("PID2", "diag2", "diag3"), Claim("PID2", "diag1", "diag3"),
  Claim("PID3", "diag1", "diag2"), Claim("PID2", "diag2", "diag4"))
val claimRDD = spark.sparkContext.parallelize(claim)
val claimPairRDD = claimRDD.map(x => (x.pid, x))

val lab = Seq(
  Lab("PID1", "lab1", "lab2"), Lab("PID1", "lab2", "lab3"),
  Lab("PID2", "lab2", "lab3"), Lab("PID2", "lab1", "lab3"),
  Lab("PID3", "lab1", "lab2"))
val labRDD = spark.sparkContext.parallelize(lab)
val labPairRDD = labRDD.map(x => (x.pid, x))

val rxclaim = Seq(
  RxClaims("PID1", "drug1", "drug2"), RxClaims("PID1", "drug2", "drug3"),
  RxClaims("PID2", "drug2", "drug3"), RxClaims("PID2", "drug1", "drug3"),
  RxClaims("PID3", "drug1", "drug2"), RxClaims("PID2", "drug2", "drug4"))
val rxclaimRDD = spark.sparkContext.parallelize(rxclaim)
val rxclaimPairRDD = rxclaimRDD.map(x => (x.pid, x))

// cogroup collects, per pid, all claims, all labs, and all rx claims in one pass
val result = claimPairRDD.cogroup(labPairRDD, rxclaimPairRDD)

result.collect().foreach(println)   // collect to the driver so the output prints locally
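Each printed line is one patient; ordering varies from run to run, but for PID3 (which has a single row in every dataset) it looks roughly like this, CompactBuffer being the Iterable implementation Spark uses internally:

(PID3,(CompactBuffer(Claim(PID3,diag1,diag2)),CompactBuffer(Lab(PID3,lab1,lab2)),CompactBuffer(RxClaims(PID3,drug1,drug2))))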