Skip to content

Commit 97457ff

Browse files
author
minjk-bl
committed
Edit to consider SMOTE
1 parent 112792c commit 97457ff

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

visualpython/js/com/component/ModelEditor.js

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ define([
143143
description: 'Transform labels to normalized encoding.'
144144
}
145145
}
146-
147146
if (modelType != 'ColumnTransformer') {
148147
actions = {
149148
...actions,
@@ -159,6 +158,32 @@ define([
159158
}
160159
}
161160
}
161+
if (modelType === 'SMOTE') {
162+
actions = {
163+
'fit': {
164+
name: 'fit',
165+
label: 'Fit',
166+
code: '${model}.fit(${fit_featureData}, ${fit_targetData})',
167+
description: 'Check inputs and statistics of the sampler.',
168+
options: [
169+
{ name: 'fit_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_train' },
170+
{ name: 'fit_targetData', label: 'Target Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'y_train' }
171+
]
172+
},
173+
'fit_resample': {
174+
name: 'fit_resample',
175+
label: 'Fit and resample',
176+
code: '${fit_res_allocateX}, ${fit_res_allocatey} = ${model}.fit_resample(${fit_res_featureData}, ${fit_res_targetData})',
177+
description: 'Resample the dataset.',
178+
options: [
179+
{ name: 'fit_res_allocateX', label: 'Allocate feature', component: ['input'], placeholder: 'New variable', value: 'X_res' },
180+
{ name: 'fit_res_allocatey', label: 'Allocate target', component: ['input'], placeholder: 'New variable', value: 'y_res' },
181+
{ name: 'fit_res_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_train' },
182+
{ name: 'fit_res_targetData', label: 'Target Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'y_train' }
183+
]
184+
}
185+
}
186+
}
162187
break;
163188
case 'Regression':
164189
actions = {
@@ -407,10 +432,11 @@ define([
407432
'fit': {
408433
name: 'fit',
409434
label: 'Fit',
410-
code: '${model}.fit(${fit_featureData})',
435+
code: '${model}.fit(${fit_featureData}${fit_targetData})',
411436
description: 'Run fit with all sets of parameters.',
412437
options: [
413-
{ name: 'fit_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X' }
438+
{ name: 'fit_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_train' },
439+
{ name: 'fit_targetData', label: 'Target Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'y_train', usePair: true, pairKey: 'y' }
414440
]
415441
},
416442
'predict': {
@@ -419,7 +445,7 @@ define([
419445
code: '${pred_allocate} = ${model}.predict(${pred_featureData})',
420446
description: 'Call predict on the estimator with the best found parameters.',
421447
options: [
422-
{ name: 'pred_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X' },
448+
{ name: 'pred_featureData', label: 'Feature Data', component: ['data_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_test' },
423449
{ name: 'pred_allocate', label: 'Allocate to', component: ['input'], placeholder: 'New variable', value: 'pred' }
424450
]
425451
},
@@ -598,6 +624,19 @@ define([
598624
}
599625
}
600626
}
627+
if (modelType === 'SMOTE') {
628+
infos = {
629+
'get_feature_names_out': {
630+
name: 'get_feature_names_out',
631+
label: 'Get feature names',
632+
code: '${feature_names_allocate} = ${model}.get_feature_names_out()',
633+
description: 'Get output feature names for transformation.',
634+
options: [
635+
{ name: 'feature_names_allocate', label: 'Allocate to', component: ['input'], placeholder: 'New variable', value: 'features' }
636+
]
637+
}
638+
}
639+
}
601640
infos = {
602641
...infos,
603642
'get_params': defaultInfos['get_params']

visualpython/js/m_ml/Pipeline.js

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ define([
6363
- Fit
6464
- Transform
6565
- Predict
66+
- Fit and Predict
67+
- Fit and Transform
68+
- Fit and Resample
6669
*/
6770
this.templateList = {
6871
'data-prep': {
@@ -73,9 +76,10 @@ define([
7376
* ml_* is pre-defined app
7477
* pp_* is defined only for Pipeline
7578
*/
76-
{ name: 'ml_dataPrep', label: 'Data Prep', useApp: true },
79+
{ name: 'ml_dataPrep', label: 'Data Prep', useApp: true, child: ['pp_fit', 'pp_transform', 'pp_fit_resample'] },
7780
{ name: 'pp_fit', label: 'Fit' },
78-
{ name: 'pp_transform', label: 'Transform' }
81+
{ name: 'pp_transform', label: 'Transform' },
82+
{ name: 'pp_fit_resample', label: 'Fit and Resample' }
7983
]
8084
},
8185
'regression': {
@@ -286,7 +290,7 @@ define([
286290
that.state.modelTypeName = modelTypeName;
287291

288292
// show fit / predict / transform depends on model selection
289-
let defaultActions = ['fit', 'predict', 'transform', 'fit_predict', 'fit_transform'];
293+
let defaultActions = ['fit', 'predict', 'transform', 'fit_predict', 'fit_transform', 'fit_resample'];
290294
let actions = that.modelEditor.getAction(modelTypeName);
291295
defaultActions.forEach(actKey => {
292296
if (actions[actKey] === undefined) {
@@ -308,6 +312,10 @@ define([
308312
} else {
309313
$(that.wrapSelector(`.vp-pp-item[data-name="pp_${actKey}"]`)).hide();
310314
}
315+
} else if (actKey === 'fit_resample') {
316+
// for SMOTE: show fit_resample only
317+
$(that.wrapSelector(`.vp-pp-item[data-name="pp_fit"]`)).hide();
318+
$(that.wrapSelector(`.vp-pp-item[data-name="pp_transform"]`)).hide();
311319
}
312320
}
313321
$(that.wrapSelector('.vp-pp-item')).removeClass('vp-last-visible');
@@ -580,6 +588,9 @@ define([
580588
case 'pp_fit_transform':
581589
tag = this.templateForOptionPage(actions['fit_transform']);
582590
break;
591+
case 'pp_fit_resample':
592+
tag = this.templateForOptionPage(actions['fit_resample']);
593+
break;
583594
}
584595
$(this.wrapSelector(`.vp-pp-step-page[data-name="${appId}"]`)).html(`
585596
<div class="vp-grid-border-box vp-grid-col-110">${tag}</div>
@@ -680,6 +691,9 @@ define([
680691
case 'pp_fit_transform':
681692
actObj = actions['fit_transform'];
682693
break;
694+
case 'pp_fit_resample':
695+
actObj = actions['fit_resample'];
696+
break;
683697
}
684698

685699
let code = new com_String();

0 commit comments

Comments
 (0)