Skip to content

Commit db834ef

Browse files
committed
scripted_metric support
1 parent 7c1e253 commit db834ef

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

src/main/java/org/nlpcn/es4sql/domain/MethodField.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package org.nlpcn.es4sql.domain;
22

3+
import java.util.HashMap;
34
import java.util.List;
5+
import java.util.Map;
46

57
import org.nlpcn.es4sql.Util;
68

@@ -27,6 +29,14 @@ public List<KVValue> getParams() {
2729
return params;
2830
}
2931

32+
public Map<String,Object> getParamsAsMap(){
33+
Map<String,Object> paramsAsMap = new HashMap<>();
34+
for(KVValue kvValue : this.params){
35+
paramsAsMap.put(kvValue.key,kvValue.value);
36+
}
37+
return paramsAsMap;
38+
}
39+
3040
@Override
3141
public String toString() {
3242
if (option != null) {

src/main/java/org/nlpcn/es4sql/domain/Select.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
public class Select extends Query {
1717

1818
// Using this functions, will cause query to execute as aggregation.
19-
private final List<String> aggsFunctions = Arrays.asList("SUM", "MAX", "MIN", "AVG", "TOPHITS", "COUNT", "STATS","EXTENDED_STATS","PERCENTILES");
19+
private final List<String> aggsFunctions = Arrays.asList("SUM", "MAX", "MIN", "AVG", "TOPHITS", "COUNT", "STATS","EXTENDED_STATS","PERCENTILES","SCRIPTED_METRIC");
2020
private List<Hint> hints = new ArrayList<>();
2121
private List<Field> fields = new ArrayList<>();
2222
private List<List<Field>> groupBys = new ArrayList<>();

src/main/java/org/nlpcn/es4sql/query/maker/AggMaker.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.search.aggregations.bucket.terms.TermsBuilder;
2121
import org.elasticsearch.search.aggregations.metrics.MetricsAggregationBuilder;
2222
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
23+
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricBuilder;
2324
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsBuilder;
2425
import org.elasticsearch.search.sort.SortOrder;
2526
import org.nlpcn.es4sql.Util;
@@ -92,6 +93,8 @@ public AbstractAggregationBuilder makeFieldAgg(MethodField field, AbstractAggreg
9293
return builder;
9394
case "TOPHITS":
9495
return makeTopHitsAgg(field);
96+
case "SCRIPTED_METRIC":
97+
return scriptedMetric(field);
9598
case "COUNT":
9699
groupMap.put(field.getAlias(), new KVValue("COUNT", parent));
97100
return makeCountAgg(field);
@@ -136,6 +139,61 @@ private AggregationBuilder<?> makeRangeGroup(MethodField field) throws SqlParseE
136139

137140
}
138141

142+
private AbstractAggregationBuilder scriptedMetric(MethodField field) throws SqlParseException {
143+
String aggName = gettAggNameFromParamsOrAlias(field);
144+
ScriptedMetricBuilder scriptedMetricBuilder = AggregationBuilders.scriptedMetric(aggName);
145+
Map<String, Object> scriptedMetricParams = field.getParamsAsMap();
146+
if(!scriptedMetricParams.containsKey("map_script") && !scriptedMetricParams.containsKey("map_script_id") && !scriptedMetricParams.containsKey("map_script_file")){
147+
throw new SqlParseException("scripted metric parameters must contain map_script/map_script_id/map_script_file parameter");
148+
}
149+
for(Map.Entry<String,Object> param : scriptedMetricParams.entrySet()) {
150+
String paramValue = param.getValue().toString();
151+
switch (param.getKey().toLowerCase()) {
152+
case "map_script":
153+
scriptedMetricBuilder.mapScript(paramValue);
154+
break;
155+
case "map_script_id":
156+
scriptedMetricBuilder.mapScriptId(paramValue);
157+
break;
158+
case "map_script_file":
159+
scriptedMetricBuilder.mapScriptFile(paramValue);
160+
break;
161+
case "init_script":
162+
scriptedMetricBuilder.initScript(paramValue);
163+
break;
164+
case "init_script_id":
165+
scriptedMetricBuilder.initScriptId(paramValue);
166+
break;
167+
case "init_script_file":
168+
scriptedMetricBuilder.initScriptFile(paramValue);
169+
break;
170+
case "combine_script":
171+
scriptedMetricBuilder.combineScript(paramValue);
172+
break;
173+
case "combine_script_id":
174+
scriptedMetricBuilder.combineScriptId(paramValue);
175+
break;
176+
case "combine_script_file":
177+
scriptedMetricBuilder.combineScriptFile(paramValue);
178+
break;
179+
case "reduce_script":
180+
scriptedMetricBuilder.reduceScript(paramValue);
181+
break;
182+
case "reduce_script_id":
183+
scriptedMetricBuilder.reduceScriptId(paramValue);
184+
break;
185+
case "reduce_script_file":
186+
scriptedMetricBuilder.reduceScriptFile(paramValue);
187+
break;
188+
case "alias":
189+
break;
190+
default:
191+
throw new SqlParseException("scripted_metric err or not define field " + param.getKey());
192+
}
193+
}
194+
return scriptedMetricBuilder;
195+
}
196+
139197
private AggregationBuilder<?> geohashGrid(MethodField field) throws SqlParseException {
140198
String aggName = gettAggNameFromParamsOrAlias(field);
141199
GeoHashGridBuilder geoHashGrid = AggregationBuilders.geohashGrid(aggName);

src/test/java/org/nlpcn/es4sql/AggregationTest.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.search.aggregations.metrics.max.Max;
1515
import org.elasticsearch.search.aggregations.metrics.min.Min;
1616
import org.elasticsearch.search.aggregations.metrics.percentiles.Percentiles;
17+
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetric;
1718
import org.elasticsearch.search.aggregations.metrics.stats.Stats;
1819
import org.elasticsearch.search.aggregations.metrics.stats.extended.ExtendedStats;
1920
import org.elasticsearch.search.aggregations.metrics.sum.Sum;
@@ -47,7 +48,8 @@ public void sumTest() throws IOException, SqlParseException, SQLFeatureNotSuppor
4748
assertThat(sum.getValue(), equalTo(25714837.0));
4849
}
4950

50-
// script on metric aggregation tests. uncomment if your elastic has scripts enable
51+
// script on metric aggregation tests. uncomment if your elastic has scripts enable (disabled by default)
52+
//todo: find a way to check if scripts are enabled
5153
// @Test
5254
// public void sumWithScriptTest() throws IOException, SqlParseException, SQLFeatureNotSupportedException {
5355
// Aggregations result = query(String.format("SELECT SUM(script('','doc[\\'balance\\'].value + doc[\\'balance\\'].value')) as doubleSum FROM %s/account", TEST_INDEX));
@@ -67,6 +69,15 @@ public void sumTest() throws IOException, SqlParseException, SQLFeatureNotSuppor
6769
// Aggregations result = query(String.format("SELECT SUM(balance + balance) FROM %s/account", TEST_INDEX));
6870
// Sum sum = result.get("SUM(script=script(balance + balance,doc('balance').value + doc('balance').value))");
6971
// assertThat(sum.getValue(), equalTo(25714837.0*2));
72+
// }
73+
//
74+
// @Test
75+
// public void scriptedMetricAggregation() throws SQLFeatureNotSupportedException, SqlParseException {
76+
// Aggregations result = query ("select scripted_metric('map_script'='if(doc[\\'balance\\'].value > 49670){ if(!_agg.containsKey(\\'ages\\')) { _agg.put(\\'ages\\',doc[\\'age\\'].value); } " +
77+
// "else { _agg.put(\\'ages\\',_agg.get(\\'ages\\')+doc[\\'age\\'].value); }}'," +
78+
// "'reduce_script'='sumThem = 0; for (a in _aggs) { if(a.containsKey(\\'ages\\')){ sumThem += a.get(\\'ages\\');} }; return sumThem;') as wierdSum from " + TEST_INDEX + "/account");
79+
// ScriptedMetric metric = result.get("wierdSum");
80+
// Assert.assertEquals(136L,metric.aggregation());
7081
// }
7182

7283
@Test

0 commit comments

Comments
 (0)