Skip to content

Commit 1aa5cf8

Browse files
author
minjk-bl
committed
Add scoring option to GridSearch
1 parent c4f6837 commit 1aa5cf8

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

visualpython/data/m_ml/mlLibrary.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,12 @@ define([
717717
'grid-search': {
718718
name: 'GridSearch',
719719
import: 'from sklearn.model_selection import GridSearchCV',
720+
code: 'GridSearchCV(${estimator}, ${param_grid}${scoring}${n_jobs}${cv}${verbose}${etc})',
720721
returnType: 'GridSearchCV',
721722
options: [
722723
{ name: 'estimator', component: ['data_select'], placeholder: 'Select model'},
723724
{ name: 'param_grid', component: ['input'], placeholder: 'Enter parameters'},
725+
{ name: 'scoring', component: ['input'], placeholder: 'None', usePair: true }, // https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
724726
{ name: 'n_jobs', component: ['input'], placeholder: 'None', usePair: true },
725727
{ name: 'cv', component: ['input'], placeholder: 'None', usePair: true },
726728
{ name: 'verbose', component: ['input_number'], placeholder: 'Input number', usePair: true }

visualpython/html/m_ml/gridSearch.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
<hr style="margin: 5px 0;" />
2323
<div class="vp-flex-gap10">
2424
<label class="vp-bold vp-param-grid-title">Param grid</label>
25-
<button class="vp-button vp-param-set-add" id="vp_addParamSet">+ Add param set</button>
2625
</div>
2726
<div class="vp-grid-box vp-param-grid-box">
2827
<div class="vp-grid-border-box vp-param-set-box">
@@ -36,6 +35,7 @@
3635
<button class="vp-button vp-param-item-add">+ Add param</button>
3736
</div>
3837
</div>
38+
<button class="vp-button vp-param-set-add" id="vp_addParamSet">+ Add param set</button>
3939
</div>
4040
</div>
4141
</body>

visualpython/js/m_ml/GridSearch.js

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ define([
8181
that.hideInstallButton();
8282
}
8383

84+
that.handleScoringOptions(modelType);
85+
8486
// reset model param set
8587
$(that.wrapSelector('.vp-param-grid-box')).html('');
8688
$(that.wrapSelector('.vp-param-grid-box')).html(that.templateForParamSet());
@@ -140,6 +142,10 @@ define([
140142
let parentTag = $(thisTag).parent();
141143
let paramIsText = $(parentTag).find('.vp-param-val').data('type') === 'text'; // text / var
142144
let paramVal = $(parentTag).find('.vp-param-val').val();
145+
let reservedKeywordList = ['None', 'True', 'False', 'np.nan', 'np.NaN'];
146+
if (reservedKeywordList.includes(paramVal)) {
147+
paramIsText = false;
148+
}
143149
// check , and split it
144150
let paramSplit = paramVal.split(',');
145151
paramSplit && paramSplit.forEach(val => {
@@ -195,6 +201,58 @@ define([
195201
});
196202
}
197203

204+
handleScoringOptions(modelType) {
205+
let options = {
206+
'Classification': [
207+
"'accuracy'",
208+
"'balanced_accuracy'",
209+
"'top_k_accuracy'",
210+
"'average_precision'",
211+
"'neg_brier_score'",
212+
"'f1'",
213+
"'f1_micro'",
214+
"'f1_macro'",
215+
"'f1_weighted'",
216+
"'f1_samples'",
217+
"'neg_log_loss'",
218+
"'precision' etc.",
219+
"'recall' etc.",
220+
"'jaccard' etc.",
221+
"'roc_auc'",
222+
"'roc_auc_ovr'",
223+
"'roc_auc_ovo'",
224+
"'roc_auc_ovr_weighted'",
225+
"'roc_auc_ovo_weighted'",
226+
],
227+
'Regression': [
228+
"'explained_variance'",
229+
"'max_error'",
230+
"'neg_mean_absolute_error'",
231+
"'neg_mean_squared_error'",
232+
"'neg_root_mean_squared_error'",
233+
"'neg_mean_squared_log_error'",
234+
"'neg_median_absolute_error'",
235+
"'r2'",
236+
"'neg_mean_poisson_deviance'",
237+
"'neg_mean_gamma_deviance'",
238+
"'neg_mean_absolute_percentage_error'",
239+
"'d2_absolute_error_score'",
240+
"'d2_pinball_score'",
241+
"'d2_tweedie_score'"
242+
]
243+
}
244+
let modelCategory = this.modelTypeList['Regression'].includes(modelType)?'Regression':'Classification';
245+
246+
// Set suggestInput on scoring option
247+
var suggestInput = new SuggestInput();
248+
suggestInput.setComponentID('scoring');
249+
suggestInput.setPlaceholder('Select option');
250+
suggestInput.addClass('vp-input vp-state');
251+
suggestInput.setSuggestList(function() { return options[modelCategory]; });
252+
suggestInput.setNormalFilter(true);
253+
$(this.wrapSelector('#scoring')).replaceWith(suggestInput.toTagString());
254+
}
255+
198256
templateForParamSet() {
199257
let paramSetNo = 1;
200258
// set param set number
@@ -378,6 +436,7 @@ define([
378436
let thisTag = $(that.wrapSelector('.' + suggestInput.uuid));
379437
that.handleAddParamValue($(thisTag));
380438
$(thisTag).val('');
439+
return false;
381440
});
382441
paramSet.appendLine(suggestInput.toTagString());
383442
}
@@ -394,6 +453,8 @@ define([
394453

395454
// Model Editor
396455
this.modelEditor = new ModelEditor(this, "model", "instanceEditor");
456+
457+
this.handleScoringOptions(this.state.modelType);
397458
}
398459

399460
generateInstallCode() {
@@ -432,7 +493,6 @@ define([
432493
state['estimator'] = estimator;
433494
state['param_grid'] = '{}';
434495

435-
let reservedKeywordList = ['None', 'True', 'False', 'np.nan', 'np.NaN'];
436496
let paramGrid = [];
437497
// generate param_grid
438498
$(this.wrapSelector('.vp-param-set-box')).each((i, tag) => {

0 commit comments

Comments
 (0)