|
13 | 13 | // [CLASS] Evaluation
|
14 | 14 | //============================================================================
|
15 | 15 | define([
|
| 16 | + 'text!vp_base/html/m_ml/evaluation.html!strip', |
16 | 17 | 'vp_base/js/com/com_util',
|
17 | 18 | 'vp_base/js/com/com_Const',
|
18 | 19 | 'vp_base/js/com/com_String',
|
19 | 20 | 'vp_base/js/com/component/PopupComponent'
|
20 |
| -], function(com_util, com_Const, com_String, PopupComponent) { |
| 21 | +], function(evalHTML, com_util, com_Const, com_String, PopupComponent) { |
21 | 22 |
|
22 | 23 | /**
|
23 | 24 | * Evaluation
|
24 | 25 | */
|
25 | 26 | class Evaluation extends PopupComponent {
|
26 | 27 | _init() {
|
27 | 28 | super._init();
|
28 |
| - /** Write codes executed before rendering */ |
| 29 | + this.config.dataview = false; |
| 30 | + |
| 31 | + this.state = { |
| 32 | + modelType: 'clf', |
| 33 | + predictData: 'pred', |
| 34 | + targetData: 'y_test', |
| 35 | + // classification |
| 36 | + confusion_matrix: true, report: true, |
| 37 | + accuracy: false, precision: false, recall: false, f1_score: false, |
| 38 | + // regression |
| 39 | + coefficient: false, intercept: false, r_squared: true, |
| 40 | + mae: false, mape: false, rmse: true, scatter_plot: false, |
| 41 | + ...this.state |
| 42 | + } |
29 | 43 | }
|
30 | 44 |
|
31 | 45 | _bindEvent() {
|
32 | 46 | super._bindEvent();
|
33 | 47 | /** Implement binding events */
|
34 | 48 | var that = this;
|
35 |
| - this.$target.on('click', function(evt) { |
36 |
| - var target = evt.target; |
37 |
| - if ($(that.wrapSelector()).find(target).length > 0) { |
38 |
| - // Sample : getDataList from Kernel |
39 |
| - vpKernel.getDataList().then(function(resultObj) { |
40 |
| - vpLog.display(VP_LOG_TYPE.DEVELOP, resultObj); |
41 |
| - }).catch(function(err) { |
42 |
| - vpLog.display(VP_LOG_TYPE.DEVELOP, err); |
43 |
| - }); |
44 |
| - } |
| 49 | + |
| 50 | + // import library |
| 51 | + $(this.wrapSelector('#vp_importLibrary')).on('click', function() { |
| 52 | + com_interface.insertCell('code', 'from sklearn import metrics'); |
45 | 53 | });
|
| 54 | + |
| 55 | + // model type change |
| 56 | + $(this.wrapSelector('#modelType')).on('change', function() { |
| 57 | + let modelType = $(this).val(); |
| 58 | + that.state.modelType = modelType; |
| 59 | + |
| 60 | + $(page).find('.vp-eval-box').hide(); |
| 61 | + $(page).find('.vp-eval-'+modelType).show(); |
| 62 | + }) |
46 | 63 | }
|
47 | 64 |
|
48 | 65 | templateForBody() {
|
49 |
| - /** Implement generating template */ |
50 |
| - return 'This is sample.'; |
| 66 | + let page = $(evalHTML); |
| 67 | + |
| 68 | + $(page).find('.vp-eval-box').hide(); |
| 69 | + $(page).find('.vp-eval-'+this.state.modelType).show(); |
| 70 | + |
| 71 | + return page; |
51 | 72 | }
|
52 | 73 |
|
53 | 74 | generateCode() {
|
54 |
| - return "print('sample code')"; |
| 75 | + let code = new com_String(); |
| 76 | + let { |
| 77 | + modelType, predictData, targetData, |
| 78 | + // classification |
| 79 | + confusion_matrix, report, accuracy, precision, recall, f1_score, |
| 80 | + // regression |
| 81 | + coefficient, intercept, r_squared, mae, mape, rmse, scatter_plot |
| 82 | + } = this.state; |
| 83 | + |
| 84 | + //==================================================================== |
| 85 | + // Classfication |
| 86 | + //==================================================================== |
| 87 | + if (modelType == 'clf') { |
| 88 | + if (confusion_matrix) { |
| 89 | + code.appendLine("# Confusion Matrix"); |
| 90 | + code.appendFormatLine('pd.crosstab({0}, {1}, margins=True)', targetData, predictData); |
| 91 | + } |
| 92 | + if (report) { |
| 93 | + code.appendLine("# Classification report"); |
| 94 | + code.appendFormatLine('print(metrics.classification_report({0}, {1}))', targetData, predictData); |
| 95 | + } |
| 96 | + if (accuracy) { |
| 97 | + code.appendLine("# Accuracy"); |
| 98 | + code.appendFormatLine('metrics.accuracy_score({0}, {1})', targetData, predictData); |
| 99 | + } |
| 100 | + if (precision) { |
| 101 | + code.appendLine("# Precision"); |
| 102 | + code.appendFormatLine("metrics.precision_score({0}, {1}, average='weighted')", targetData, predictData); |
| 103 | + } |
| 104 | + if (recall) { |
| 105 | + code.appendLine("# Recall"); |
| 106 | + code.appendFormatLine("metrics.recall_score({0}, {1}, average='weighted')", targetData, predictData); |
| 107 | + } |
| 108 | + if (f1_score) { |
| 109 | + code.appendLine("# F1-score"); |
| 110 | + code.appendFormatLine("metrics.f1_score({0}, {1}, average='weighted')", targetData, predictData); |
| 111 | + } |
| 112 | + } |
| 113 | + |
| 114 | + //==================================================================== |
| 115 | + // Regression |
| 116 | + //==================================================================== |
| 117 | + if (modelType == 'rgs') { |
| 118 | + if (coefficient) { |
| 119 | + code.appendLine("# Coefficient (scikit-learn only)"); |
| 120 | + code.appendFormatLine('model.coef_'); |
| 121 | + } |
| 122 | + if (intercept) { |
| 123 | + code.appendLine("# Intercept (scikit-learn only)"); |
| 124 | + code.appendFormatLine('model.intercept_'); |
| 125 | + } |
| 126 | + if (r_squared) { |
| 127 | + code.appendLine("# R square"); |
| 128 | + code.appendFormatLine('metrics.r2_score({0}, {1})', targetData, predictData); |
| 129 | + } |
| 130 | + if (mae) { |
| 131 | + code.appendLine("# MAE(Mean Absolute Error)"); |
| 132 | + code.appendFormatLine('metrics.mean_absolute_error({0}, {1})', targetData, predictData); |
| 133 | + } |
| 134 | + if (mape) { |
| 135 | + code.appendLine("# MAPE(Mean Absolute Percentage Error)"); |
| 136 | + code.appendLine('def MAPE(y_test, y_pred):'); |
| 137 | + code.appendLine(' return np.mean(np.abs((y_test - pred) / y_test)) * 100'); |
| 138 | + code.appendLine(); |
| 139 | + code.appendFormatLine('MAPE({0}, {1})', targetData, predictData); |
| 140 | + } |
| 141 | + if (rmse) { |
| 142 | + code.appendLine("# RMSE(Root Mean Squared Error)"); |
| 143 | + code.appendFormatLine('metrics.mean_squared_error({0}, {1})**0.5', targetData, predictData); |
| 144 | + } |
| 145 | + if (scatter_plot) { |
| 146 | + code.appendLine('# Regression plot'); |
| 147 | + code.appendFormatLine('plt.scatter({0}, {1})', targetData, predictData); |
| 148 | + code.appendFormatLine("plt.xlabel('{0}')", targetData); |
| 149 | + code.appendFormatLine("plt.ylabel('{1}')", predictData); |
| 150 | + code.appendLine('plt.show()'); |
| 151 | + } |
| 152 | + } |
| 153 | + // FIXME: as seperated cells |
| 154 | + return code.toString(); |
55 | 155 | }
|
56 | 156 |
|
57 | 157 | }
|
|
0 commit comments