Skip to content

Commit d5d56bf

Browse files
author
minjk-bl
committed
Add plot_feature_importances_ to ModelInfo
1 parent 8bcc288 commit d5d56bf

File tree

2 files changed

+95
-26
lines changed

2 files changed

+95
-26
lines changed

js/com/com_generatorV2.js

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,14 @@ define([
258258
content = renderTabBlock(pageThis, obj, state);
259259
break;
260260
case 'bool_checkbox':
261+
content = $(`<input type="checkbox" id="${obj.name}" class="vp-checkbox"/>`);
262+
if (value != undefined) {
263+
// set as saved value
264+
$(content).attr({
265+
'checked': value
266+
});
267+
}
268+
break;
261269
case 'bool_select':
262270
// True False select box
263271
var optSlct = $(`<select id="${obj.name}" class="vp-select vp-state"></select>`);
@@ -553,13 +561,17 @@ define([
553561
value = input;
554562
break;
555563
case 'option_checkbox':
556-
var checked = $(pageThis.wrapSelector("input[name='"+obj.name+"']:checked")).val();
564+
let checked = $(pageThis.wrapSelector("input[name='"+obj.name+"']:checked")).val();
557565

558566
for (var i = 0; i < checked.length; i++) {
559567
value += "'" + $(checked[i]).val() + "',";
560568
}
561569
value = value.substr(0, value.length-1);
562570
break;
571+
case 'bool_checkbox':
572+
let isChecked = $(pageThis.wrapSelector('#'+obj.name)).prop('checked');
573+
value = isChecked?'True':'False';
574+
break;
563575
case 'input_multi':
564576
case 'bool_select':
565577
case 'var_select':

js/m_ml/ModelInfo.js

Lines changed: 82 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -313,9 +313,17 @@ define([
313313

314314
generateCode() {
315315
let { model } = this.state;
316+
let codeList = [];
316317
let code = new com_String();
317318
let replaceDict = {'${model}': model};
318319

320+
// If functions are available
321+
if (this.state.optionConfig.functions != undefined) {
322+
this.state.optionConfig.functions.forEach(func => {
323+
codeList.push(func);
324+
});
325+
}
326+
319327
// If import code is available, generate its code in front of code
320328
if (this.state.optionConfig.import != undefined) {
321329
code.appendLine(this.state.optionConfig.import);
@@ -342,8 +350,9 @@ define([
342350
}
343351
}
344352
}
353+
codeList.push(code.toString());
345354

346-
return code.toString();
355+
return codeList;
347356
}
348357

349358
getModelCategory(modelType) {
@@ -399,22 +408,61 @@ define([
399408
{ name: 'importance_allocate', label: 'Allocate to', component: ['input'], placeholder: 'New variable', value: 'importances' }
400409
]
401410
},
411+
'feature_importances': {
412+
name: 'feature_importances',
413+
label: 'Feature importances',
414+
functions: [
415+
"def create_feature_importances(model, X_train=None, sort=False):\
416+
\n if isinstance(X_train, pd.core.frame.DataFrame):\
417+
\n feature_names = X_train.columns\
418+
\n else:\n\
419+
\n feature_names = [ 'X{}'.format(i) for i in range(len(model.feature_importances_)) ]\
420+
\n\
421+
\n df_i = pd.DataFrame(model.feature_importances_, index=feature_names, columns=['Feature_importance'])\
422+
\n df_i['Percentage'] = 100 * (df_i['Feature_importance'] / df_i['Feature_importance'].max())\
423+
\n if sort: df_i.sort_values(by='Feature_importance', ascending=False, inplace=True)\
424+
\n df_i = df_i.round(2)\
425+
\n\
426+
\n return df_i"
427+
],
428+
code: "${fi_allocate} = create_feature_importances(${model}, ${fi_featureData}${sort})",
429+
description: 'Allocate feature_importances_',
430+
options: [
431+
{ name: 'fi_featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_train' },
432+
{ name: 'fi_allocate', label: 'Allocate to', component: ['input'], placeholder: 'New variable', value: 'df_i' },
433+
{ name: 'sort', label: 'Sort data', component: ['bool_checkbox'], value: true, usePair: true }
434+
]
435+
},
402436
'plot_feature_importances': {
403437
name: 'plot_feature_importances',
404438
label: 'Plot feature importances',
405-
code: "def plot_feature_importances(model):\n\
406-
n_features = len(model.feature_importances_)\n\
407-
feature_names = [ 'X{}'.format(i) for i in range(n_features) ]\n\
408-
plt.barh(np.arange(n_features), model.feature_importances_, align='center')\n\
409-
plt.yticks(np.arange(n_features), feature_names)\n\
410-
plt.xlabel('Feature importance')\n\
411-
plt.ylabel('Features')\n\
412-
plt.ylim(-1, n_features)\n\
413-
plt.show()\n\n\
414-
plot_feature_importances(${model})",
415-
description: '',
439+
functions: [
440+
"def create_feature_importances(model, X_train=None, sort=False):\
441+
\n if isinstance(X_train, pd.core.frame.DataFrame):\
442+
\n feature_names = X_train.columns\
443+
\n else:\n\
444+
\n feature_names = [ 'X{}'.format(i) for i in range(len(model.feature_importances_)) ]\
445+
\n\
446+
\n df_i = pd.DataFrame(model.feature_importances_, index=feature_names, columns=['Feature_importance'])\
447+
\n df_i['Percentage'] = 100 * (df_i['Feature_importance'] / df_i['Feature_importance'].max())\
448+
\n if sort: df_i.sort_values(by='Feature_importance', ascending=False, inplace=True)\
449+
\n df_i = df_i.round(2)\
450+
\n\
451+
\n return df_i",
452+
"def plot_feature_importances(model, X_train=None, sort=False):\
453+
\n df_i = create_feature_importances(model, X_train, sort)\
454+
\n\
455+
\n df_i['Percentage'].sort_values().plot(kind='barh')\
456+
\n plt.xlabel('Feature importance Percentage')\
457+
\n plt.ylabel('Features')\
458+
\n\
459+
\n plt.show()"
460+
],
461+
code: "plot_feature_importances(${model}, ${fi_featureData}${sort})",
462+
description: 'Draw feature_importances_',
416463
options: [
417-
464+
{ name: 'fi_featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_train' },
465+
{ name: 'sort', label: 'Sort data', component: ['bool_checkbox'], value: true, usePair: true }
418466
]
419467
}
420468
}
@@ -522,6 +570,7 @@ plot_feature_importances(${model})",
522570
]
523571
},
524572
'permutation_importance': defaultInfos['permutation_importance'],
573+
'feature_importances': defaultInfos['feature_importances'],
525574
'plot_feature_importances': defaultInfos['plot_feature_importances'],
526575
'Coefficient': {
527576
name: 'coef_',
@@ -573,11 +622,11 @@ plot_feature_importances(${model})",
573622
name: 'roc_curve',
574623
label: 'ROC Curve',
575624
import: 'from sklearn import metrics',
576-
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.predict_proba(${roc_featureData})[:, 1])\n\
577-
plt.plot(fpr, tpr, label='ROC Curve')\n\
578-
plt.xlabel('Sensitivity')\n\
579-
plt.ylabel('Specificity')\n\
580-
plt.show()",
625+
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.predict_proba(${roc_featureData})[:, 1])\
626+
\nplt.plot(fpr, tpr, label='ROC Curve')\
627+
\nplt.xlabel('Sensitivity')\
628+
\nplt.ylabel('Specificity')\
629+
\nplt.show()",
581630
description: '',
582631
options: [
583632
{ name: 'roc_targetData', label: 'Target Data', component: ['var_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'y_test' },
@@ -595,8 +644,16 @@ plt.show()",
595644
{ name: 'auc_featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'X_test' }
596645
]
597646
},
598-
'permutation_importance': defaultInfos['permutation_importance'],
599-
'plot_feature_importances': defaultInfos['plot_feature_importances']
647+
'permutation_importance': defaultInfos['permutation_importance']
648+
}
649+
650+
// feature importances
651+
if (modelType != 'LogisticRegression' && modelType != 'SVC') {
652+
infos = {
653+
...infos,
654+
'feature_importances': defaultInfos['feature_importances'],
655+
'plot_feature_importances': defaultInfos['plot_feature_importances']
656+
}
600657
}
601658

602659
// use decision_function on ROC, AUC
@@ -608,11 +665,11 @@ plt.show()",
608665
...infos,
609666
'roc_curve': {
610667
...infos['roc_curve'],
611-
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\n\
612-
plt.plot(fpr, tpr, label='ROC Curve')\n\
613-
plt.xlabel('Sensitivity')\n\
614-
plt.ylabel('Specificity')\n\
615-
plt.show()"
668+
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\
669+
\nplt.plot(fpr, tpr, label='ROC Curve')\
670+
\nplt.xlabel('Sensitivity')\
671+
\nplt.ylabel('Specificity')\
672+
\nplt.show()"
616673
},
617674
'auc': {
618675
...infos['auc'],

0 commit comments

Comments
 (0)