Skip to content

Commit 1dde971

Browse files
Vikas Nelamangala authored and mengxr committed
[SPARK-11549][DOCS] Replace example code in mllib-evaluation-metrics.md using include_example
Author: Vikas Nelamangala <vikasnelamangala@Vikass-MacBook-Pro.local> Closes apache#9689 from vikasnp/master. (cherry picked from commit ed47b1e) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 0665fb5 commit 1dde971

16 files changed

+1319
-925
lines changed

docs/mllib-evaluation-metrics.md

Lines changed: 15 additions & 925 deletions
Large diffs are not rendered by default.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
// $example on$
21+
import scala.Tuple2;
22+
23+
import org.apache.spark.api.java.*;
24+
import org.apache.spark.api.java.function.Function;
25+
import org.apache.spark.mllib.classification.LogisticRegressionModel;
26+
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
27+
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
28+
import org.apache.spark.mllib.regression.LabeledPoint;
29+
import org.apache.spark.mllib.util.MLUtils;
30+
// $example off$
31+
import org.apache.spark.SparkConf;
32+
import org.apache.spark.SparkContext;
33+
34+
public class JavaBinaryClassificationMetricsExample {
35+
public static void main(String[] args) {
36+
SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
37+
SparkContext sc = new SparkContext(conf);
38+
// $example on$
39+
String path = "data/mllib/sample_binary_classification_data.txt";
40+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
41+
42+
// Split initial RDD into two... [60% training data, 40% testing data].
43+
JavaRDD<LabeledPoint>[] splits =
44+
data.randomSplit(new double[]{0.6, 0.4}, 11L);
45+
JavaRDD<LabeledPoint> training = splits[0].cache();
46+
JavaRDD<LabeledPoint> test = splits[1];
47+
48+
// Run training algorithm to build the model.
49+
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
50+
.setNumClasses(2)
51+
.run(training.rdd());
52+
53+
// Clear the prediction threshold so the model will return probabilities
54+
model.clearThreshold();
55+
56+
// Compute raw scores on the test set.
57+
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
58+
new Function<LabeledPoint, Tuple2<Object, Object>>() {
59+
public Tuple2<Object, Object> call(LabeledPoint p) {
60+
Double prediction = model.predict(p.features());
61+
return new Tuple2<Object, Object>(prediction, p.label());
62+
}
63+
}
64+
);
65+
66+
// Get evaluation metrics.
67+
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
68+
69+
// Precision by threshold
70+
JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
71+
System.out.println("Precision by threshold: " + precision.toArray());
72+
73+
// Recall by threshold
74+
JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
75+
System.out.println("Recall by threshold: " + recall.toArray());
76+
77+
// F Score by threshold
78+
JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
79+
System.out.println("F1 Score by threshold: " + f1Score.toArray());
80+
81+
JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
82+
System.out.println("F2 Score by threshold: " + f2Score.toArray());
83+
84+
// Precision-recall curve
85+
JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
86+
System.out.println("Precision-recall curve: " + prc.toArray());
87+
88+
// Thresholds
89+
JavaRDD<Double> thresholds = precision.map(
90+
new Function<Tuple2<Object, Object>, Double>() {
91+
public Double call(Tuple2<Object, Object> t) {
92+
return new Double(t._1().toString());
93+
}
94+
}
95+
);
96+
97+
// ROC Curve
98+
JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
99+
System.out.println("ROC curve: " + roc.toArray());
100+
101+
// AUPRC
102+
System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
103+
104+
// AUROC
105+
System.out.println("Area under ROC = " + metrics.areaUnderROC());
106+
107+
// Save and load model
108+
model.save(sc, "target/tmp/LogisticRegressionModel");
109+
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
110+
"target/tmp/LogisticRegressionModel");
111+
// $example off$
112+
}
113+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
// $example on$
21+
import java.util.Arrays;
22+
import java.util.List;
23+
24+
import scala.Tuple2;
25+
26+
import org.apache.spark.api.java.*;
27+
import org.apache.spark.mllib.evaluation.MultilabelMetrics;
28+
import org.apache.spark.rdd.RDD;
29+
import org.apache.spark.SparkConf;
30+
// $example off$
31+
import org.apache.spark.SparkContext;
32+
33+
public class JavaMultiLabelClassificationMetricsExample {
34+
public static void main(String[] args) {
35+
SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example");
36+
JavaSparkContext sc = new JavaSparkContext(conf);
37+
// $example on$
38+
List<Tuple2<double[], double[]>> data = Arrays.asList(
39+
new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
40+
new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
41+
new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
42+
new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
43+
new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
44+
new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
45+
new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
46+
);
47+
JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
48+
49+
// Instantiate metrics object
50+
MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
51+
52+
// Summary stats
53+
System.out.format("Recall = %f\n", metrics.recall());
54+
System.out.format("Precision = %f\n", metrics.precision());
55+
System.out.format("F1 measure = %f\n", metrics.f1Measure());
56+
System.out.format("Accuracy = %f\n", metrics.accuracy());
57+
58+
// Stats by labels
59+
for (int i = 0; i < metrics.labels().length - 1; i++) {
60+
System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision
61+
(metrics.labels()[i]));
62+
System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics
63+
.labels()[i]));
64+
System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure
65+
(metrics.labels()[i]));
66+
}
67+
68+
// Micro stats
69+
System.out.format("Micro recall = %f\n", metrics.microRecall());
70+
System.out.format("Micro precision = %f\n", metrics.microPrecision());
71+
System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
72+
73+
// Hamming loss
74+
System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
75+
76+
// Subset accuracy
77+
System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
78+
// $example off$
79+
}
80+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
// $example on$
21+
import scala.Tuple2;
22+
23+
import org.apache.spark.api.java.*;
24+
import org.apache.spark.api.java.function.Function;
25+
import org.apache.spark.mllib.classification.LogisticRegressionModel;
26+
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
27+
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
28+
import org.apache.spark.mllib.regression.LabeledPoint;
29+
import org.apache.spark.mllib.util.MLUtils;
30+
import org.apache.spark.mllib.linalg.Matrix;
31+
// $example off$
32+
import org.apache.spark.SparkConf;
33+
import org.apache.spark.SparkContext;
34+
35+
public class JavaMulticlassClassificationMetricsExample {
36+
public static void main(String[] args) {
37+
SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example");
38+
SparkContext sc = new SparkContext(conf);
39+
// $example on$
40+
String path = "data/mllib/sample_multiclass_classification_data.txt";
41+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
42+
43+
// Split initial RDD into two... [60% training data, 40% testing data].
44+
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
45+
JavaRDD<LabeledPoint> training = splits[0].cache();
46+
JavaRDD<LabeledPoint> test = splits[1];
47+
48+
// Run training algorithm to build the model.
49+
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
50+
.setNumClasses(3)
51+
.run(training.rdd());
52+
53+
// Compute raw scores on the test set.
54+
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
55+
new Function<LabeledPoint, Tuple2<Object, Object>>() {
56+
public Tuple2<Object, Object> call(LabeledPoint p) {
57+
Double prediction = model.predict(p.features());
58+
return new Tuple2<Object, Object>(prediction, p.label());
59+
}
60+
}
61+
);
62+
63+
// Get evaluation metrics.
64+
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
65+
66+
// Confusion matrix
67+
Matrix confusion = metrics.confusionMatrix();
68+
System.out.println("Confusion matrix: \n" + confusion);
69+
70+
// Overall statistics
71+
System.out.println("Precision = " + metrics.precision());
72+
System.out.println("Recall = " + metrics.recall());
73+
System.out.println("F1 Score = " + metrics.fMeasure());
74+
75+
// Stats by labels
76+
for (int i = 0; i < metrics.labels().length; i++) {
77+
System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision
78+
(metrics.labels()[i]));
79+
System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics
80+
.labels()[i]));
81+
System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure
82+
(metrics.labels()[i]));
83+
}
84+
85+
//Weighted stats
86+
System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
87+
System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
88+
System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
89+
System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());
90+
91+
// Save and load model
92+
model.save(sc, "target/tmp/LogisticRegressionModel");
93+
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
94+
"target/tmp/LogisticRegressionModel");
95+
// $example off$
96+
}
97+
}

0 commit comments

Comments
 (0)