Skip to content

Commit e8073fa

Browse files
divide mean image by number of training data points
1 parent 463c9bf commit e8073fa

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/main/scala/apps/ImageNetApp.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,20 @@ object ImageNetApp {
5757
var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema)
5858
var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema)
5959

60+
val numTrainData = trainDF.count()
61+
logger.log("numTrainData = " + numTrainData.toString)
62+
val numTestData = testDF.count()
63+
logger.log("numTestData = " + numTestData.toString)
64+
6065
logger.log("computing mean image")
6166
val meanImage = trainDF.map(row => row(0).asInstanceOf[Array[Byte]].map(e => e.toLong))
6267
.reduce((a, b) => (a, b).zipped.map(_ + _))
63-
.map(e => e.toFloat)
68+
.map(e => (e.toDouble / numTrainData).toFloat)
6469

6570
logger.log("coalescing") // if you want to shuffle your data, replace coalesce with repartition
6671
trainDF = trainDF.coalesce(numWorkers)
6772
testDF = testDF.coalesce(numWorkers)
6873

69-
val numTrainData = trainDF.count()
70-
logger.log("numTrainData = " + numTrainData.toString)
71-
72-
val numTestData = testDF.count()
73-
logger.log("numTestData = " + numTestData.toString)
74-
7574
val trainPartitionSizes = trainDF.mapPartitions(iter => Array(iter.size).iterator).persist()
7675
val testPartitionSizes = testDF.mapPartitions(iter => Array(iter.size).iterator).persist()
7776
trainPartitionSizes.foreach(size => workerStore.put("trainPartitionSize", size))

0 commit comments

Comments
 (0)