Skip to content

Commit 890bb96

Browse files
author
Yusuke Sugomori
committed
cpp
1 parent 3af35b4 commit 890bb96

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.DS_Store
2-
*.pyc
2+
*.pyc
3+
*.out

cpp/LogisticRegression.cpp

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include <iostream>
2+
#include <string>
3+
#include <math.h>
4+
#include "LogisticRegression.h"
5+
using namespace std;
6+
7+
8+
// Construct a classifier for `in` input features and `out` classes.
// `size` is the number of training examples; it scales the gradient
// steps taken in train().
LogisticRegression::LogisticRegression(int size, int in, int out) {
    N = size;
    n_in = in;
    n_out = out;

    // W is indexed W[class][feature] throughout train()/predict(), so it
    // must be n_out rows of n_in columns. (The original allocated n_in
    // rows of n_out columns — out-of-bounds whenever in != out.)
    // Zero-initialize W and b explicitly: train()/predict() accumulate
    // with +=, and operator new[] does not zero doubles.
    W = new double*[n_out];
    for (int i = 0; i < n_out; i++) {
        W[i] = new double[n_in];
        for (int j = 0; j < n_in; j++) W[i][j] = 0.0;
    }

    b = new double[n_out];
    for (int i = 0; i < n_out; i++) b[i] = 0.0;
}
18+
19+
// Release the weight matrix and bias vector.
// NOTE(review): this frees n_out rows, which matches the W[class][feature]
// indexing used by train()/predict() — but the constructor as written
// allocates n_in rows, so rows leak (or worse, double-free territory)
// whenever in != out. Confirm the allocation is n_out rows of n_in.
LogisticRegression::~LogisticRegression() {
    for (int i=0; i<n_out; i++) delete[] W[i];
    delete[] W;
    delete[] b;
}
24+
25+
26+
void LogisticRegression::train(int *x, int *y, double lr) {
27+
int i,j;
28+
double p_y_given_x[n_out];
29+
double dy[n_out];
30+
31+
for (i=0; i<n_out; i++) {
32+
for (j=0; j<n_in; j++) {
33+
p_y_given_x[i] += W[i][j] * x[j];
34+
}
35+
p_y_given_x[i] += b[i];
36+
}
37+
softmax(p_y_given_x);
38+
39+
for (i=0; i<n_out; i++) {
40+
dy[i] = y[i] - p_y_given_x[i];
41+
42+
for (j=0; j<n_in; j++) {
43+
W[i][j] += lr * dy[i] * x[j] / N;
44+
}
45+
46+
b[i] += lr * dy[i] / N;
47+
}
48+
}
49+
50+
// In-place softmax over x[0..n_out-1].
// Subtracts the running maximum before exponentiating so large scores
// cannot overflow exp().
void LogisticRegression::softmax(double *x) {
    // Both accumulators were previously read without initialization (UB):
    // seed max from the first element and sum from zero.
    double max = x[0];
    double sum = 0.0;

    int i;
    for (i = 1; i < n_out; i++) if (max < x[i]) max = x[i];
    for (i = 0; i < n_out; i++) {
        x[i] = exp(x[i] - max);
        sum += x[i];
    }

    for (i = 0; i < n_out; i++) x[i] /= sum;
}
63+
64+
void LogisticRegression::predict(int *x, double *y) {
65+
for (int i=0; i<n_out; i++) {
66+
for (int j=0; j<n_in; j++) {
67+
y[i] += W[i][j] * x[j];
68+
}
69+
y[i] += b[i];
70+
}
71+
72+
softmax(y);
73+
}
74+
75+
76+
void test_lr() {
77+
int i,j;
78+
79+
double learning_rate = 0.1;
80+
double n_epochs = 500;
81+
82+
int train_N = 6;
83+
int test_N = 1;
84+
int n_in = 6;
85+
int n_out = 2;
86+
// int **train_X;
87+
// int **train_Y;
88+
// int **test_X;
89+
// double **test_Y;
90+
91+
// train_X = new int*[train_N];
92+
// train_Y = new int*[train_N];
93+
// for (i=0; i<train_N; i++){
94+
// train_X[i] = new int[n_in];
95+
// train_Y[i] = new int[n_out];
96+
// };
97+
98+
// test_X = new int*[test_N];
99+
// test_Y = new double*[test_N];
100+
// for (i=0; i<test_N; i++){
101+
// test_X[i] = new int[n_in];
102+
// test_Y[i] = new double[n_out];
103+
// }
104+
105+
106+
// training data
107+
int train_X[6][6] = {
108+
{1, 1, 1, 0, 0, 0},
109+
{1, 0, 1, 0, 0, 0},
110+
{1, 1, 1, 0, 0, 0},
111+
{0, 0, 1, 1, 1, 0},
112+
{0, 0, 1, 1, 0, 0},
113+
{0, 0, 1, 1, 1, 0}
114+
};
115+
116+
int train_Y[6][2] = {
117+
{1, 0},
118+
{1, 0},
119+
{1, 0},
120+
{0, 1},
121+
{0, 1},
122+
{0, 1}
123+
};
124+
125+
126+
// construct LogisticRegression
127+
LogisticRegression classifier(train_N, n_in, n_out);
128+
129+
130+
// train online
131+
for (int epoch=0; epoch<n_epochs; epoch++) {
132+
for (i=0; i<train_N; i++) {
133+
classifier.train(train_X[i], train_Y[i], learning_rate);
134+
}
135+
learning_rate *= 0.95;
136+
}
137+
138+
139+
// test data
140+
int test_X[1][6] = {
141+
{1, 1, 1, 0, 0, 0}
142+
};
143+
144+
double test_Y[1][2];
145+
146+
147+
// test
148+
for (i=0; i<test_N; i++) {
149+
classifier.predict(test_X[i], test_Y[i]);
150+
for (j=0; j<n_out; j++) {
151+
cout << test_Y[i][j] << endl;
152+
}
153+
}
154+
155+
}
156+
157+
158+
// Entry point: run the logistic-regression demo and exit successfully.
int main() {
    test_lr();

    return 0;
}

cpp/LogisticRegression.h

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Multiclass logistic-regression (softmax) classifier trained by
// per-example stochastic gradient descent.
class LogisticRegression {

public:
    int N;        // num of inputs (training-set size; divides gradient steps in train())
    int n_in;     // number of input features
    int n_out;    // number of output classes
    double **W;   // weight matrix; indexed W[class][feature] by train()/predict()
    double *b;    // per-class bias vector, length n_out
    // (size, in, out) -> N, n_in, n_out; allocates W and b
    LogisticRegression(int, int, int);
    ~LogisticRegression();
    // one SGD step: (features x, one-hot label y, learning rate)
    void train(int*, int*, double);
    // in-place softmax over an n_out-length array
    void softmax(double*);
    // fill y (length n_out) with class probabilities for features x
    void predict(int*, double*);
};

0 commit comments

Comments
 (0)