Skip to content

Commit d99ad4d

Browse files
committed
Improved the filtering of result fields
1 parent f9b291e commit d99ad4d

File tree

7 files changed

+57
-27
lines changed

7 files changed

+57
-27
lines changed

pmml-evaluator-example/src/main/java/org/jpmml/evaluator/example/TestingExample.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.jpmml.evaluator.EvaluatorBuilder;
3535
import org.jpmml.evaluator.ModelEvaluatorBuilder;
3636
import org.jpmml.evaluator.ModelEvaluatorFactory;
37+
import org.jpmml.evaluator.ResultField;
3738
import org.jpmml.evaluator.testing.Batch;
3839
import org.jpmml.evaluator.testing.BatchUtil;
3940
import org.jpmml.evaluator.testing.Conflict;
@@ -154,14 +155,14 @@ public void execute() throws Exception {
154155

155156
List<? extends Map<FieldName, ?>> outputRecords = BatchUtil.parseRecords(outputTable, cellParser);
156157

157-
Predicate<FieldName> predicate;
158+
Predicate<ResultField> predicate;
158159

159160
if(this.ignoredFields != null && !this.ignoredFields.isEmpty()){
160-
predicate = (FieldName name) -> !this.ignoredFields.contains(name);
161+
predicate = (ResultField resultField) -> !this.ignoredFields.contains(resultField.getName());
161162
} else
162163

163164
{
164-
predicate = (FieldName name) -> true;
165+
predicate = (ResultField resultField) -> true;
165166
}
166167

167168
Equivalence<Object> equivalence = new PMMLEquivalence(this.precision, this.zeroThreshold);
@@ -178,7 +179,7 @@ public void execute() throws Exception {
178179
}
179180

180181
static
181-
private Batch createBatch(Evaluator evaluator, List<? extends Map<FieldName, ?>> input, List<? extends Map<FieldName, ?>> output, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
182+
private Batch createBatch(Evaluator evaluator, List<? extends Map<FieldName, ?>> input, List<? extends Map<FieldName, ?>> output, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
182183
Batch batch = new Batch(){
183184

184185
@Override
@@ -197,7 +198,7 @@ public Evaluator getEvaluator(){
197198
}
198199

199200
@Override
200-
public Predicate<FieldName> getPredicate(){
201+
public Predicate<ResultField> getPredicate(){
201202
return predicate;
202203
}
203204

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/ArchiveBatch.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.jpmml.evaluator.FieldNameSet;
3535
import org.jpmml.evaluator.FunctionNameStack;
3636
import org.jpmml.evaluator.ModelEvaluatorBuilder;
37+
import org.jpmml.evaluator.ResultField;
3738
import org.jpmml.evaluator.visitors.DefaultModelEvaluatorBattery;
3839
import org.jpmml.model.PMMLUtil;
3940
import org.jpmml.model.visitors.VisitorBattery;
@@ -45,12 +46,12 @@ public class ArchiveBatch implements Batch {
4546

4647
private String dataset = null;
4748

48-
private Predicate<FieldName> predicate = null;
49+
private Predicate<ResultField> predicate = null;
4950

5051
private Equivalence<Object> equivalence = null;
5152

5253

53-
public ArchiveBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
54+
public ArchiveBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
5455
setName(Objects.requireNonNull(name));
5556
setDataset(Objects.requireNonNull(dataset));
5657
setPredicate(Objects.requireNonNull(predicate));
@@ -149,11 +150,11 @@ private void setDataset(String dataset){
149150
}
150151

151152
@Override
152-
public Predicate<FieldName> getPredicate(){
153+
public Predicate<ResultField> getPredicate(){
153154
return this.predicate;
154155
}
155156

156-
private void setPredicate(Predicate<FieldName> predicate){
157+
private void setPredicate(Predicate<ResultField> predicate){
157158
this.predicate = predicate;
158159
}
159160

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/Batch.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.common.base.Equivalence;
2626
import org.dmg.pmml.FieldName;
2727
import org.jpmml.evaluator.Evaluator;
28+
import org.jpmml.evaluator.ResultField;
2829

2930
public interface Batch extends AutoCloseable {
3031

@@ -55,7 +56,7 @@ public interface Batch extends AutoCloseable {
5556
* (between expected and actual output data records).
5657
* </p>
5758
*/
58-
Predicate<FieldName> getPredicate();
59+
Predicate<ResultField> getPredicate();
5960

6061
/**
6162
* <p>

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/BatchUtil.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import java.util.LinkedHashSet;
2525
import java.util.List;
2626
import java.util.Map;
27-
import java.util.Objects;
2827
import java.util.Set;
2928
import java.util.function.Function;
3029
import java.util.function.Predicate;
@@ -36,6 +35,9 @@
3635
import org.jpmml.evaluator.Evaluator;
3736
import org.jpmml.evaluator.EvaluatorUtil;
3837
import org.jpmml.evaluator.HasGroupFields;
38+
import org.jpmml.evaluator.OutputField;
39+
import org.jpmml.evaluator.ResultField;
40+
import org.jpmml.evaluator.TargetField;
3941

4042
public class BatchUtil {
4143

@@ -59,7 +61,29 @@ public List<Conflict> evaluate(Batch batch) throws Exception {
5961
throw new IllegalArgumentException("Expected the same number of data rows, got " + input.size() + " input data rows and " + output.size() + " expected output data rows");
6062
}
6163

62-
Predicate<FieldName> predicate = (batch.getPredicate()).and(name -> !Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name));
64+
Predicate<ResultField> predicate = batch.getPredicate();
65+
66+
Set<FieldName> names = new LinkedHashSet<>();
67+
68+
List<TargetField> targetFields = evaluator.getTargetFields();
69+
for(TargetField targetField : targetFields){
70+
71+
if(targetField.isSynthetic()){
72+
continue;
73+
} // End if
74+
75+
if(predicate.test(targetField)){
76+
names.add(targetField.getName());
77+
}
78+
}
79+
80+
List<OutputField> outputFields = evaluator.getOutputFields();
81+
for(OutputField outputField : outputFields){
82+
83+
if(predicate.test(outputField)){
84+
names.add(outputField.getName());
85+
}
86+
}
6387

6488
Equivalence<Object> equivalence = batch.getEquivalence();
6589

@@ -69,11 +93,11 @@ public List<Conflict> evaluate(Batch batch) throws Exception {
6993
Map<FieldName, ?> arguments = input.get(i);
7094

7195
Map<FieldName, ?> expectedResults = output.get(i);
72-
expectedResults = Maps.filterKeys(expectedResults, predicate::test);
96+
expectedResults = Maps.filterKeys(expectedResults, names::contains);
7397

7498
try {
7599
Map<FieldName, ?> actualResults = evaluator.evaluate(arguments);
76-
actualResults = Maps.filterKeys(actualResults, predicate::test);
100+
actualResults = Maps.filterKeys(actualResults, names::contains);
77101

78102
MapDifference<FieldName, ?> difference = Maps.<FieldName, Object>difference(expectedResults, actualResults, equivalence);
79103
if(!difference.areEqual()){

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/FilterBatch.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.common.base.Equivalence;
2626
import org.dmg.pmml.FieldName;
2727
import org.jpmml.evaluator.Evaluator;
28+
import org.jpmml.evaluator.ResultField;
2829

2930
public class FilterBatch implements Batch {
3031

@@ -57,7 +58,7 @@ public Evaluator getEvaluator() throws Exception {
5758
}
5859

5960
@Override
60-
public Predicate<FieldName> getPredicate(){
61+
public Predicate<ResultField> getPredicate(){
6162
Batch batch = getBatch();
6263

6364
return batch.getPredicate();

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/IntegrationTest.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020

2121
import java.util.Arrays;
2222
import java.util.Collection;
23+
import java.util.LinkedHashSet;
2324
import java.util.Objects;
2425
import java.util.function.Predicate;
2526

2627
import com.google.common.base.Equivalence;
2728
import org.dmg.pmml.FieldName;
29+
import org.jpmml.evaluator.ResultField;
2830

2931
abstract
3032
public class IntegrationTest extends BatchTest {
@@ -40,18 +42,18 @@ public void evaluate(String name, String dataset) throws Exception {
4042
evaluate(name, dataset, null, null);
4143
}
4244

43-
public void evaluate(String name, String dataset, Predicate<FieldName> predicate) throws Exception {
45+
public void evaluate(String name, String dataset, Predicate<ResultField> predicate) throws Exception {
4446
evaluate(name, dataset, predicate, null);
4547
}
4648

4749
public void evaluate(String name, String dataset, Equivalence<Object> equivalence) throws Exception {
4850
evaluate(name, dataset, null, equivalence);
4951
}
5052

51-
public void evaluate(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence) throws Exception {
53+
public void evaluate(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence) throws Exception {
5254

5355
if(predicate == null){
54-
predicate = (x -> true);
56+
predicate = (resultField -> true);
5557
} // End if
5658

5759
if(equivalence == null){
@@ -63,7 +65,7 @@ public void evaluate(String name, String dataset, Predicate<FieldName> predicate
6365
}
6466
}
6567

66-
protected Batch createBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
68+
protected Batch createBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
6769
Batch result = new IntegrationTestBatch(name, dataset, predicate, equivalence){
6870

6971
@Override
@@ -84,17 +86,17 @@ private void setEquivalence(Equivalence<Object> equivalence){
8486
}
8587

8688
static
87-
public Predicate<FieldName> excludeFields(FieldName... names){
88-
return excludeFields(Arrays.asList(names));
89+
public Predicate<ResultField> excludeFields(FieldName... names){
90+
return excludeFields(new LinkedHashSet<>(Arrays.asList(names)));
8991
}
9092

9193
static
92-
public Predicate<FieldName> excludeFields(Collection<FieldName> names){
93-
Predicate<FieldName> predicate = new Predicate<FieldName>(){
94+
public Predicate<ResultField> excludeFields(Collection<FieldName> names){
95+
Predicate<ResultField> predicate = new Predicate<ResultField>(){
9496

9597
@Override
96-
public boolean test(FieldName name){
97-
return !names.contains(name);
98+
public boolean test(ResultField resultField){
99+
return !names.contains(resultField.getName());
98100
}
99101
};
100102

pmml-evaluator-testing/src/main/java/org/jpmml/evaluator/testing/IntegrationTestBatch.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import com.google.common.base.Equivalence;
2828
import org.dmg.pmml.Application;
29-
import org.dmg.pmml.FieldName;
3029
import org.dmg.pmml.MiningSchema;
3130
import org.dmg.pmml.PMML;
3231
import org.dmg.pmml.Visitor;
@@ -35,6 +34,7 @@
3534
import org.jpmml.evaluator.EvaluatorBuilder;
3635
import org.jpmml.evaluator.ModelEvaluatorBuilder;
3736
import org.jpmml.evaluator.OutputFilters;
37+
import org.jpmml.evaluator.ResultField;
3838
import org.jpmml.evaluator.visitors.InvalidMarkupInspector;
3939
import org.jpmml.evaluator.visitors.UnsupportedMarkupInspector;
4040
import org.jpmml.model.SerializationUtil;
@@ -46,7 +46,7 @@ public class IntegrationTestBatch extends ArchiveBatch {
4646
private Evaluator evaluator = null;
4747

4848

49-
public IntegrationTestBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
49+
public IntegrationTestBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
5050
super(name, dataset, predicate, equivalence);
5151
}
5252

0 commit comments

Comments
 (0)