Skip to content

Commit 3be72e8

Browse files
author
James Lee
committed
split TypedDataset
1 parent caf78df commit 3be72e8

File tree

3 files changed

+94
-50
lines changed

3 files changed

+94
-50
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package com.sparkTutorial.sparkSql;
2+
3+
import org.apache.log4j.Level;
4+
import org.apache.log4j.Logger;
5+
import org.apache.spark.SparkConf;
6+
import org.apache.spark.api.java.JavaRDD;
7+
import org.apache.spark.api.java.JavaSparkContext;
8+
import org.apache.spark.sql.Dataset;
9+
import org.apache.spark.sql.Encoders;
10+
import org.apache.spark.sql.SparkSession;
11+
12+
public class RddToDataset {
13+
14+
private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";
15+
16+
public static void main(String[] args) throws Exception {
17+
18+
Logger.getLogger("org").setLevel(Level.ERROR);
19+
SparkConf conf = new SparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]");
20+
21+
JavaSparkContext sc = new JavaSparkContext(conf);
22+
23+
SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate();
24+
25+
JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv");
26+
27+
JavaRDD<Response> responseRDD = lines
28+
.filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country"))
29+
.map(line -> {
30+
String[] splits = line.split(COMMA_DELIMITER, -1);
31+
return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14]));
32+
});
33+
Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class));
34+
35+
System.out.println("=== Print out schema ===");
36+
responseDataset.printSchema();
37+
38+
System.out.println("=== Print 20 records of responses table ===");
39+
responseDataset.show(20);
40+
41+
JavaRDD<Response> responseJavaRDD = responseDataset.toJavaRDD();
42+
43+
for (Response response : responseJavaRDD.collect()) {
44+
System.out.println(response);
45+
}
46+
47+
}
48+
49+
private static Integer convertStringToFloat(String split) {
50+
return split.isEmpty() ? null : Math.round(Float.valueOf(split));
51+
}
52+
53+
}

src/main/java/com/sparkTutorial/sparkSql/Response.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
public class Response implements Serializable {
66
private String country;
7-
private float ageMidPoint;
7+
private Integer ageMidPoint;
88
private String occupation;
9-
private float salaryMidPoint;
9+
private Integer salaryMidPoint;
1010

11-
public Response(String country, float ageMidPoint, String occupation, float salaryMidPoint) {
11+
public Response(String country, Integer ageMidPoint, String occupation, Integer salaryMidPoint) {
1212
this.country = country;
1313
this.ageMidPoint = ageMidPoint;
1414
this.occupation = occupation;
@@ -26,11 +26,11 @@ public void setCountry(String country) {
2626
this.country = country;
2727
}
2828

29-
public float getAgeMidPoint() {
29+
public Integer getAgeMidPoint() {
3030
return ageMidPoint;
3131
}
3232

33-
public void setAgeMidPoint(float ageMidPoint) {
33+
public void setAgeMidPoint(Integer ageMidPoint) {
3434
this.ageMidPoint = ageMidPoint;
3535
}
3636

@@ -42,11 +42,22 @@ public void setOccupation(String occupation) {
4242
this.occupation = occupation;
4343
}
4444

45-
public float getSalaryMidPoint() {
45+
public Integer getSalaryMidPoint() {
4646
return salaryMidPoint;
4747
}
4848

49-
public void setSalaryMidPoint(float salaryMidPoint) {
49+
public void setSalaryMidPoint(Integer salaryMidPoint) {
5050
this.salaryMidPoint = salaryMidPoint;
5151
}
52+
53+
54+
@Override
55+
public String toString() {
56+
return "Response{" +
57+
"country='" + country + '\'' +
58+
", ageMidPoint=" + ageMidPoint +
59+
", occupation='" + occupation + '\'' +
60+
", salaryMidPoint=" + salaryMidPoint +
61+
'}';
62+
}
5263
}

src/main/java/com/sparkTutorial/sparkSql/TypedDataset.java

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,81 +2,61 @@
22

33
import org.apache.log4j.Level;
44
import org.apache.log4j.Logger;
5-
import org.apache.spark.SparkConf;
6-
import org.apache.spark.api.java.JavaRDD;
7-
import org.apache.spark.api.java.JavaSparkContext;
8-
import org.apache.spark.sql.Dataset;
9-
import org.apache.spark.sql.Encoders;
10-
import org.apache.spark.sql.SparkSession;
5+
import org.apache.spark.sql.*;
116

127
import static org.apache.spark.sql.functions.avg;
8+
import static org.apache.spark.sql.functions.col;
139
import static org.apache.spark.sql.functions.max;
1410

1511

1612
public class TypedDataset {
1713
private static final String AGE_MIDPOINT = "ageMidpoint";
1814
private static final String SALARY_MIDPOINT = "salaryMidPoint";
1915
private static final String SALARY_MIDPOINT_BUCKET = "salaryMidpointBucket";
20-
private static final float NULL_VALUE = -1.0f;
21-
private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";
2216

2317
public static void main(String[] args) throws Exception {
2418

2519
Logger.getLogger("org").setLevel(Level.ERROR);
26-
SparkConf conf = new SparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]");
20+
SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate();
2721

28-
JavaSparkContext sc = new JavaSparkContext(conf);
22+
DataFrameReader dataFrameReader = session.read();
2923

30-
SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate();
24+
Dataset<Row> responses = dataFrameReader.option("header","true").csv("in/2016-stack-overflow-survey-responses.csv");
3125

32-
JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv");
26+
Dataset<Row> responseWithSelectedColumns = responses.select(col("country"), col("age_midpoint").as("ageMidPoint").cast("integer"), col("occupation"), col("salary_midpoint").as("salaryMidPoint").cast("integer"));
3327

34-
JavaRDD<Response> responseRDD = lines
35-
.filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country"))
36-
.map(line -> {
37-
String[] splits = line.split(COMMA_DELIMITER, -1);
38-
return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14]));
39-
});
40-
Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class));
28+
Dataset<Response> typedDataset = responseWithSelectedColumns.as(Encoders.bean(Response.class));
4129

4230
System.out.println("=== Print out schema ===");
43-
responseDataset.printSchema();
31+
typedDataset.printSchema();
4432

4533
System.out.println("=== Print 20 records of responses table ===");
46-
responseDataset.show(20);
34+
typedDataset.show(20);
4735

4836
System.out.println("=== Print records where the response is from Afghanistan ===");
49-
responseDataset.filter(response -> response.getCountry().equals("Afghanistan")).show();
37+
typedDataset.filter(response -> response.getCountry().equals("Afghanistan")).show();
5038

5139
System.out.println("=== Print the count of occupations ===");
52-
responseDataset.groupBy(responseDataset.col("occupation")).count().show();
53-
40+
typedDataset.groupBy(typedDataset.col("occupation")).count().show();
5441

5542
System.out.println("=== Print records with average mid age less than 20 ===");
56-
responseDataset.filter(response -> response.getAgeMidPoint() != NULL_VALUE && response.getAgeMidPoint() < 20).show();
43+
typedDataset.filter(response -> response.getAgeMidPoint() !=null && response.getAgeMidPoint() < 20).show();
5744

5845
System.out.println("=== Print the result with salary middle point in descending order ===");
59-
responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT ).desc()).show();
46+
typedDataset.orderBy(typedDataset.col(SALARY_MIDPOINT ).desc()).show();
6047

6148
System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ===");
62-
responseDataset
63-
.filter(response -> response.getSalaryMidPoint() != NULL_VALUE)
64-
.groupBy("country")
65-
.agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT))
66-
.show();
49+
typedDataset.filter(response -> response.getSalaryMidPoint() != null)
50+
.groupBy("country")
51+
.agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT))
52+
.show();
6753

6854
System.out.println("=== Group by salary bucket ===");
69-
70-
responseDataset
71-
.map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT())
72-
.withColumnRenamed("value", SALARY_MIDPOINT_BUCKET)
73-
.groupBy(SALARY_MIDPOINT_BUCKET)
74-
.count()
75-
.orderBy(SALARY_MIDPOINT_BUCKET).show();
55+
typedDataset.filter(response -> response.getSalaryMidPoint() != null)
56+
.map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT())
57+
.withColumnRenamed("value", SALARY_MIDPOINT_BUCKET)
58+
.groupBy(SALARY_MIDPOINT_BUCKET)
59+
.count()
60+
.orderBy(SALARY_MIDPOINT_BUCKET).show();
7661
}
77-
78-
private static float convertStringToFloat(String split) {
79-
return split.isEmpty() ? NULL_VALUE : Float.valueOf(split);
80-
}
81-
8262
}

0 commit comments

Comments
 (0)