|
2 | 2 |
|
3 | 3 | import org.apache.log4j.Level;
|
4 | 4 | import org.apache.log4j.Logger;
|
5 |
| -import org.apache.spark.SparkConf; |
6 |
| -import org.apache.spark.api.java.JavaRDD; |
7 |
| -import org.apache.spark.api.java.JavaSparkContext; |
8 |
| -import org.apache.spark.sql.Dataset; |
9 |
| -import org.apache.spark.sql.Encoders; |
10 |
| -import org.apache.spark.sql.SparkSession; |
| 5 | +import org.apache.spark.sql.*; |
11 | 6 |
|
12 | 7 | import static org.apache.spark.sql.functions.avg;
|
| 8 | +import static org.apache.spark.sql.functions.col; |
13 | 9 | import static org.apache.spark.sql.functions.max;
|
14 | 10 |
|
15 | 11 |
|
16 | 12 | public class TypedDataset {
|
17 | 13 | private static final String AGE_MIDPOINT = "ageMidpoint";
|
18 | 14 | private static final String SALARY_MIDPOINT = "salaryMidPoint";
|
19 | 15 | private static final String SALARY_MIDPOINT_BUCKET = "salaryMidpointBucket";
|
20 |
| - private static final float NULL_VALUE = -1.0f; |
21 |
| - private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"; |
22 | 16 |
|
23 | 17 | public static void main(String[] args) throws Exception {
|
24 | 18 |
|
25 | 19 | Logger.getLogger("org").setLevel(Level.ERROR);
|
26 |
| - SparkConf conf = new SparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]"); |
| 20 | + SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
27 | 21 |
|
28 |
| - JavaSparkContext sc = new JavaSparkContext(conf); |
| 22 | + DataFrameReader dataFrameReader = session.read(); |
29 | 23 |
|
30 |
| - SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
| 24 | + Dataset<Row> responses = dataFrameReader.option("header","true").csv("in/2016-stack-overflow-survey-responses.csv"); |
31 | 25 |
|
32 |
| - JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv"); |
| 26 | + Dataset<Row> responseWithSelectedColumns = responses.select(col("country"), col("age_midpoint").as("ageMidPoint").cast("integer"), col("occupation"), col("salary_midpoint").as("salaryMidPoint").cast("integer")); |
33 | 27 |
|
34 |
| - JavaRDD<Response> responseRDD = lines |
35 |
| - .filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country")) |
36 |
| - .map(line -> { |
37 |
| - String[] splits = line.split(COMMA_DELIMITER, -1); |
38 |
| - return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14])); |
39 |
| - }); |
40 |
| - Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class)); |
| 28 | + Dataset<Response> typedDataset = responseWithSelectedColumns.as(Encoders.bean(Response.class)); |
41 | 29 |
|
42 | 30 | System.out.println("=== Print out schema ===");
|
43 |
| - responseDataset.printSchema(); |
| 31 | + typedDataset.printSchema(); |
44 | 32 |
|
45 | 33 | System.out.println("=== Print 20 records of responses table ===");
|
46 |
| - responseDataset.show(20); |
| 34 | + typedDataset.show(20); |
47 | 35 |
|
48 | 36 | System.out.println("=== Print records where the response is from Afghanistan ===");
|
49 |
| - responseDataset.filter(response -> response.getCountry().equals("Afghanistan")).show(); |
| 37 | + typedDataset.filter(response -> response.getCountry().equals("Afghanistan")).show(); |
50 | 38 |
|
51 | 39 | System.out.println("=== Print the count of occupations ===");
|
52 |
| - responseDataset.groupBy(responseDataset.col("occupation")).count().show(); |
53 |
| - |
| 40 | + typedDataset.groupBy(typedDataset.col("occupation")).count().show(); |
54 | 41 |
|
55 | 42 | System.out.println("=== Print records with average mid age less than 20 ===");
|
56 |
| - responseDataset.filter(response -> response.getAgeMidPoint() != NULL_VALUE && response.getAgeMidPoint() < 20).show(); |
| 43 | + typedDataset.filter(response -> response.getAgeMidPoint() !=null && response.getAgeMidPoint() < 20).show(); |
57 | 44 |
|
58 | 45 | System.out.println("=== Print the result with salary middle point in descending order ===");
|
59 |
| - responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT ).desc()).show(); |
| 46 | + typedDataset.orderBy(typedDataset.col(SALARY_MIDPOINT ).desc()).show(); |
60 | 47 |
|
61 | 48 | System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ===");
|
62 |
| - responseDataset |
63 |
| - .filter(response -> response.getSalaryMidPoint() != NULL_VALUE) |
64 |
| - .groupBy("country") |
65 |
| - .agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT)) |
66 |
| - .show(); |
| 49 | + typedDataset.filter(response -> response.getSalaryMidPoint() != null) |
| 50 | + .groupBy("country") |
| 51 | + .agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT)) |
| 52 | + .show(); |
67 | 53 |
|
68 | 54 | System.out.println("=== Group by salary bucket ===");
|
69 |
| - |
70 |
| - responseDataset |
71 |
| - .map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT()) |
72 |
| - .withColumnRenamed("value", SALARY_MIDPOINT_BUCKET) |
73 |
| - .groupBy(SALARY_MIDPOINT_BUCKET) |
74 |
| - .count() |
75 |
| - .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
| 55 | + typedDataset.filter(response -> response.getSalaryMidPoint() != null) |
| 56 | + .map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT()) |
| 57 | + .withColumnRenamed("value", SALARY_MIDPOINT_BUCKET) |
| 58 | + .groupBy(SALARY_MIDPOINT_BUCKET) |
| 59 | + .count() |
| 60 | + .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
76 | 61 | }
|
77 |
| - |
78 |
| - private static float convertStringToFloat(String split) { |
79 |
| - return split.isEmpty() ? NULL_VALUE : Float.valueOf(split); |
80 |
| - } |
81 |
| - |
82 | 62 | }
|
0 commit comments