Skip to content

Commit 6952050

Browse files
author
lebesgue
committed
add SGHT & .gitignore
1 parent 4dc11e6 commit 6952050

File tree

7 files changed

+396
-0
lines changed

7 files changed

+396
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*~
2+
*/*~

SGHT/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
Sparse Group Hard Thresholding (SGHT)
2+
3+
> Sparse Group Feature Selection (SGFS) via SGHT.
4+
5+
6+
## Introduction
7+
The `.m` files are the MATLAB interfaces for solving the Sparse Group Feature Selection problem, while the `sght_*.cpp` files contain the key proximal part, i.e., the **Sparse Group Hard Thresholding (SGHT)** problem. Currently there are three versions available:
8+
9+
- `sght.cpp`: **recommended**, best readability, regular DP format
10+
- `sght_external.cpp`: uses external memory, complicated index conversion
11+
- `sght_persistent.cpp`: **unstable and deprecated**, keeps the DP table persistent across calls, complicated index conversion, requires calling `clear sght_persistent` properly
12+
13+
## Functions
14+
15+
| Tables              | FISTA           | ISTA  | Barzilai-Borwein | Const | Lipschitz | Sufficient Decrease|
16+
| --------------------|:---------------:|:-----:|:----------------:|:-----:|:--------:|:------------------:|
17+
| `sghtFISTA.m` | Y | |Y | |Y | |
18+
| `sghtFISTAConst.m` | Y | | |Y |Y | |
19+
| `sghtISTA.m` | |Y |Y | | |Y |
20+
| `sghtISTAConst.m` | |Y | |Y | |Y |
21+
| `sghtISTAWolfe.m` | |Y |Y | |Y | |
22+
23+
24+
- framework: FISTA/ISTA:
25+
- Line search initialization: Barzilai-Borwein/const
26+
- Line search criterion: Lipschitz/sufficient decrease
27+
28+
## Usage
29+
30+
See `test_sght.m` for details on calling the functions in MATLAB. Make sure to run `mex sght.cpp` before that.

SGHT/sght.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* Matlab usage: x = sght(v, ind, s1, s2)
3+
*/
4+
#include <cstdio>
5+
#include <cstdlib>
6+
#include <algorithm>
7+
#include <cstring>
8+
#include <iostream>
9+
#include <cmath>
10+
#include <mex.h>
11+
12+
using namespace std;
13+
14+
int n, g;  // n: number of elements of v; g: number of groups (element count of ind minus one)
15+
int flmt, glmt;  // feature budget (s1) and group budget (s2), read in mexFunction
16+
double *v, *x;  // v: input vector data; x: output vector data (same length n)
17+
int *vidx;  // permutation of 0..n-1; dp_preprocess sorts each group's slice by decreasing |v|
18+
int *gidx;  // integer copy of ind; group i occupies positions gidx[i-1] .. gidx[i]-1
19+
20+
// 1-based in the group index: entries with group index 0 are the zero-filled
// base case. d[i][j][k] is the largest sum of squared entries obtainable from
// groups 1..i when picking from at most j groups and at most k entries in
// total; path[i][j][k] is how many entries of group i that optimum takes;
// selection_value[i][t] is the sum of squares of the t largest-magnitude
// entries of group i.
21+
double ***d;
22+
int ***path;
23+
double **selection_value;
24+
25+
bool cmp (int x, int y) {
26+
return fabs(v[x]) > fabs(v[y]) + 1e-9;
27+
}
28+
29+
void dp_preprocess () {
30+
int i, j, k;
31+
32+
d = (double ***)malloc((g + 1) * sizeof(double **));
33+
path = (int ***)malloc((g + 1) * sizeof(int **));
34+
for (i = 0; i <= g; ++i) {
35+
d[i] = (double **)malloc((glmt + 1) * sizeof(double *));
36+
path[i] = (int **)malloc((glmt + 1) * sizeof(int *));
37+
for (j = 0; j <= glmt; ++j) {
38+
d[i][j] = (double *)malloc((flmt + 1) * sizeof(double));
39+
path[i][j] = (int *)malloc((flmt + 1) * sizeof(int));
40+
for (k = 0; k <= flmt; ++k) {
41+
d[i][j][k] = .0;
42+
path[i][j][k] = 0;
43+
}
44+
}
45+
}
46+
int up = 0;
47+
for (i = 1; i <= g; ++i)
48+
up = max(up, gidx[i] - gidx[i-1]);
49+
50+
selection_value = (double **)malloc((g + 1) * sizeof(double *));
51+
for (i = 0; i <= g; ++i)
52+
selection_value[i] = (double *)malloc((up + 1) * sizeof(double));
53+
54+
vidx = (int *)malloc(n * sizeof(int));
55+
for (i = 0; i < n; ++i) vidx[i] = i;
56+
57+
for (i = 1; i <= g; ++i) {
58+
sort(vidx + gidx[i-1], vidx + gidx[i], cmp);
59+
selection_value[i][0] = .0;
60+
for (j = 1; j <= gidx[i] - gidx[i-1]; ++j)
61+
selection_value[i][j] = selection_value[i][j-1] +
62+
v[vidx[gidx[i-1]+j-1]] * v[vidx[gidx[i-1]+j-1]];
63+
}
64+
}
65+
66+
void dp () {
67+
int i, j, k, t;
68+
69+
dp_preprocess ();
70+
71+
for (i = 1; i <= g; ++i) {
72+
for (j = 1; j <= glmt; ++j)
73+
for (k = 1; k <= flmt; ++k) {
74+
d[i][j][k] = d[i-1][j][k];
75+
int u = min(k, gidx[i] - gidx[i-1]); // min(|G_i|, k)
76+
int max_idx = 0;
77+
double best = d[i][j][k], v;
78+
for (t = 1; t <= u; ++t) {
79+
if ((v = d[i-1][j-1][k-t] + selection_value[i][t]) > best) {
80+
best = v;
81+
max_idx = t;
82+
}
83+
}
84+
d[i][j][k] = best;
85+
path[i][j][k] = max_idx;
86+
}
87+
}
88+
}
89+
90+
void calc_sol () {
91+
int i, j, k;
92+
memset (x, 0, n * sizeof(double));
93+
94+
for (i = g, j = glmt, k = flmt; i >= 1; --i) {
95+
int num = path[i][j][k];
96+
k -= num;
97+
if (num) --j;
98+
for (int t = 1; t <= num; ++t)
99+
x[vidx[gidx[i-1]+t-1]] = v[vidx[gidx[i-1]+t-1]];
100+
}
101+
}
102+
103+
void destruct () {
104+
int i, j;
105+
106+
//free(v); v = NULL;
107+
free(vidx); vidx = NULL;
108+
free(gidx); gidx = NULL;
109+
for (i = 0; i <= g; ++i) {
110+
for (j = 0; j <= glmt; ++j) {
111+
free(d[i][j]); d[i][j] = NULL;
112+
free(path[i][j]); path[i][j] = NULL;
113+
}
114+
free(d[i]); d[i] = NULL;
115+
free(path[i]); path[i] = NULL;
116+
}
117+
free(d); d = NULL;
118+
free(path); path = NULL;
119+
120+
for (i = 0; i <= g; ++i) {
121+
free(selection_value[i]);
122+
selection_value[i] = NULL;
123+
}
124+
free(selection_value);
125+
selection_value = NULL;
126+
}
127+
128+
void mexFunction (int nlhs, mxArray* plhs[],
129+
int nrhs, const mxArray* prhs[])
130+
{
131+
v = mxGetPr(prhs[0]);
132+
double* gidx_double = mxGetPr(prhs[1]);
133+
134+
n = mxGetNumberOfElements(prhs[0]);
135+
g = mxGetNumberOfElements(prhs[1]) - 1;
136+
137+
gidx = (int *)malloc((g + 1) * sizeof(int));
138+
for (int i = 0; i <= g; ++i) gidx[i] = (int)gidx_double[i];
139+
140+
flmt = mxGetScalar(prhs[2]);
141+
glmt = mxGetScalar(prhs[3]);
142+
143+
double eps = 1e-7;
144+
145+
plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
146+
x = mxGetPr(plhs[0]);
147+
148+
dp ();
149+
calc_sol ();
150+
//mexPrintf ("sglt: Optimal solution is %.5lf\n", d[g][glmt][flmt]);
151+
152+
destruct ();
153+
}
154+

SGHT/sght.mexa64

13.3 KB
Binary file not shown.

SGHT/sghtISTA.m

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
function [ x, fval ] = sghtISTA( A, y, ind, s1, s2 )
%sghtISTA Least squares under the sght step, with a sufficient-decrease search.
%   [x, fval] = sghtISTA(A, y, ind, s1, s2) starts from x = 0 and repeats
%       x_new = sght(x - g/L, ind, s1, s2),   g = A' * (A*x - y),
%   for at most 1000 objective evaluations of f(x) = 0.5 * norm(A*x - y)^2.
%
%   Step size: the first step starts from L = 1; later steps start from the
%   ratio (delta_x' * delta_g) / (delta_x' * delta_x) of the last two
%   iterates (L = 1 when that ratio is undefined or zero). L is doubled until
%       f(x_new) < f(x) - L/2 * norm(x_new - x)^2.
%   If L stops being finite the search gives up and keeps the current
%   candidate; without this guard the search would loop forever whenever no
%   strict decrease exists (for example when the gradient is zero).
%
%   Inputs:
%     A    matrix with n columns
%     y    vector such that A*x - y is defined
%     ind  group boundaries; ind(1) == 0 and ind(end) == n are asserted
%     s1   forwarded to sght, asserted to be <= n
%     s2   forwarded to sght, asserted to be <= length(ind) - 1
%
%   Outputs:
%     x     last iterate
%     fval  objective values; fval(1) is the value at x = 0
%
%   Stops when the objective changes by less than tol * previous value, when
%   the norm of the gradient used for the step is below tol, or when the
%   objective is below tol (tol = 1e-5).

n = size(A, 2);
grp = length(ind) - 1;
x = zeros(n, 1);

maxIter = 1000;
fval = zeros(maxIter, 1);
tol = 1e-5;

%% Sanity check
assert(ind(grp+1) == n);
assert(ind(1) == 0);
assert(s1 <= n);
assert(s2 <= grp);

%% First step: backtracking from L = 1
fval(1) = 0.5 * norm(A * x - y, 2)^2;
x_last = x;
g_last = A' * (A * x - y);
L = 1;

while true
    x_cur = sght(x_last - g_last/L, ind, s1, s2);
    ax = A * x_cur;
    fval_cur = 0.5 * norm(ax - y, 2)^2;
    % Accept on sufficient decrease; bail out once L is Inf/NaN, because
    % the test can never succeed from then on.
    if (fval_cur < fval(1) - L/2 * norm(x_cur - x_last, 2)^2) || ~isfinite(L)
        break
    end
    L = L * 2;
end

fval(2) = fval_cur;

%% Main loop
for itr = 3 : maxIter

    g_cur = A' * (ax - y);

    delta_x = x_cur - x_last;
    delta_g = g_cur - g_last;

    % Initial L for this iteration from the last two iterates.
    if (norm(delta_x) == 0 || delta_x' * delta_g == 0)
        L = 1;
    else
        L = (delta_x' * delta_g) / (delta_x' * delta_x);
    end

    while true
        x = sght(x_cur - g_cur/L, ind, s1, s2);
        ax = A * x;
        f = 0.5 * norm(ax - y, 2)^2;
        if (f < fval(itr-1) - L/2 * norm(x - x_cur, 2)^2) || ~isfinite(L)
            break
        end
        L = L * 2;
    end
    fval(itr) = f;
    x_last = x_cur;
    g_last = g_cur;
    x_cur = x;

    if (abs(f - fval(itr-1)) < tol * fval(itr-1) || norm(g_cur) < tol || f < tol)
        break
    end
end

% Drop the unused tail of the preallocated history.
fval = fval(1 : itr);
end
79+

SGHT/sghtISTAWolfe.m

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
function [ x, fval ] = sghtISTAWolfe( A, y, ind, s1, s2, verbose )
%sghtISTAWolfe Least squares under the sght step, with a curvature-bound search.
%   [x, fval] = sghtISTAWolfe(A, y, ind, s1, s2) starts from x = 0 and repeats
%       x_new = sght(x - g/L, ind, s1, s2),   g = A' * (A*x - y),
%   for at most 1000 objective evaluations of f(x) = 0.5 * norm(A*x - y)^2.
%
%   [x, fval] = sghtISTAWolfe(A, y, ind, s1, s2, verbose) additionally
%   controls the per-iteration progress printout. verbose defaults to true,
%   which is the behaviour of the five-argument call.
%
%   Step size: the first step starts from L = 1; later steps start from the
%   ratio (delta_x' * delta_g) / (delta_x' * delta_x) of the last two
%   iterates (L = 1 when that ratio is undefined or zero). L is doubled until
%       norm(A * (x_new - x))^2 <= L * norm(x_new - x)^2.
%   If L stops being finite the search gives up and keeps the current
%   candidate instead of looping forever.
%
%   Inputs:
%     A        matrix with n columns
%     y        vector such that A*x - y is defined
%     ind      group boundaries; ind(1) == 0 and ind(end) == n are asserted
%     s1       forwarded to sght, asserted to be <= n
%     s2       forwarded to sght, asserted to be <= length(ind) - 1
%     verbose  (optional) print iteration number, initial L and objective
%
%   Outputs:
%     x     last iterate
%     fval  objective values; fval(1) is the value at x = 0
%
%   Stops when the objective changes by less than tol * previous value, when
%   the norm of the gradient used for the step is below tol, or when the
%   objective is below tol (tol = 1e-5).

if nargin < 6 || isempty(verbose)
    verbose = true;
end

n = size(A, 2);
grp = length(ind) - 1;
x = zeros(n, 1);

maxIter = 1000;
fval = zeros(maxIter, 1);
tol = 1e-5;

%% Sanity check
assert(ind(grp+1) == n);
assert(ind(1) == 0);
assert(s1 <= n);
assert(s2 <= grp);

%% First step: backtracking from L = 1
fval(1) = 0.5 * norm(A * x - y, 2)^2;
x_last = x;
g_last = A' * (A * x - y);
L = 1;

while true
    x_cur = sght(x_last - g_last/L, ind, s1, s2);
    ax = A * x_cur;
    fval_cur = 0.5 * norm(ax - y, 2)^2;
    dlt = x_cur - x_last;
    left = norm(A * dlt, 2)^2;
    right = dlt' * dlt;
    % Accept when L bounds the curvature along the step; bail out once L is
    % Inf/NaN, because doubling cannot change the outcome any more.
    if left <= L * right || ~isfinite(L)
        break
    end
    L = L * 2;
end

fval(2) = fval_cur;

%% Main loop
for itr = 3 : maxIter

    g_cur = A' * (ax - y);

    delta_x = x_cur - x_last;
    delta_g = g_cur - g_last;

    % Initial L for this iteration from the last two iterates.
    if (norm(delta_x) == 0 || delta_x' * delta_g == 0)
        L = 1;
    else
        L = (delta_x' * delta_g) / (delta_x' * delta_x);
    end

    if verbose
        fprintf ('\tIteration %d\n', itr);
        fprintf ('\t\tL is initialized to %f\n', L);
        fprintf ('\t\tobj is %f\n', fval(itr-1));
    end

    while true
        x = sght(x_cur - g_cur/L, ind, s1, s2);
        ax = A * x;
        f = 0.5 * norm(ax - y, 2)^2;

        dlt = x - x_cur;
        left = norm(A * dlt, 2)^2;
        right = dlt' * dlt;
        if left <= L * right || ~isfinite(L)
            break
        end

        L = L * 2;
    end
    fval(itr) = f;
    x_last = x_cur;
    g_last = g_cur;
    x_cur = x;

    if (abs(f - fval(itr-1)) < tol * fval(itr-1) || norm(g_cur) < tol || f < tol)
        break
    end
end

% Drop the unused tail of the preallocated history.
fval = fval(1 : itr);
end
88+

0 commit comments

Comments
 (0)