Skip to content

Commit 062d72a

Browse files
author
lebesgue
committed
add exclusive sght
1 parent 6952050 commit 062d72a

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed

SGHT/sghtISTAWolfeExclusive.m

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
function [ x, fval ] = sghtISTAWolfeExclusive( A, y, ind, s1, s2, upper)
2+
%sghtISTAWolfe Summary of this function goes here
3+
% Detailed explanation goes here
4+
5+
[m, n] = size(A);
6+
grp = length(ind) - 1;
7+
x = zeros(n, 1);
8+
9+
maxIter = 1000;
10+
fval = zeros(maxIter, 1);
11+
tol = 1e-5;
12+
13+
%% Sanity check
14+
assert(ind(grp+1) == n);
15+
assert(ind(1) == 0);
16+
assert(s1 <= n);
17+
assert(s2 <= grp);
18+
19+
if nargin <= 5
20+
upper = ind(2 : grp + 1) - ind(1 : grp);
21+
end
22+
23+
%%
24+
fval(1) = 0.5 * norm(A * x - y, 2)^2;
25+
x_last = x;
26+
g_last = A' * (A * x - y);
27+
L = 1;
28+
29+
while true
30+
x_cur = sght_exclusive(x_last - g_last/L, ind, s1, s2, upper);
31+
ax = A * x_cur;
32+
fval_cur = 0.5 * norm(ax - y, 2)^2;
33+
dlt = x_cur - x_last;
34+
left = norm(A * dlt, 2)^2;
35+
right = dlt' * dlt;
36+
if left <= L * right
37+
break
38+
end
39+
L = L * 2;
40+
% fprintf ('\t\t L = %f\n', L);
41+
end
42+
43+
fval(2) = fval_cur;
44+
45+
46+
for itr = 3 : maxIter
47+
48+
g_cur = A' * (ax - y);
49+
50+
delta_x = x_cur - x_last;
51+
delta_g = g_cur - g_last;
52+
53+
if (norm(delta_x) == 0 || delta_x' * delta_g == 0)
54+
L = 1;
55+
else
56+
L = (delta_x' * delta_g) / (delta_x' * delta_x);
57+
end
58+
59+
% if (rem(itr, 10) == 3)
60+
fprintf ('\tIteration %d\n', itr);
61+
fprintf ('\t\tL is initialized to %f\n', L);
62+
fprintf ('\t\tobj is %f\n', fval(itr-1));
63+
% end
64+
65+
while true
66+
x = sght(x_cur - g_cur/L, ind, s1, s2);
67+
ax = A * x;
68+
f = 0.5 * norm(ax - y, 2)^2;
69+
70+
dlt = x - x_cur;
71+
left = norm(A * dlt, 2)^2;
72+
right = dlt' * dlt;
73+
if left <= L * right
74+
break
75+
end
76+
77+
L = L * 2;
78+
% fprintf ('L = %d\n', L);
79+
end
80+
fval(itr) = f;
81+
x_last = x_cur;
82+
g_last = g_cur;
83+
x_cur = x;
84+
85+
if (abs(f - fval(itr-1)) < tol * fval(itr-1) || norm(g_cur) < tol || f < tol)
86+
break
87+
end
88+
end
89+
90+
fval = fval(1 : itr);
91+
end
92+

SGHT/sght_exclusive.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Matlab usage: x = sght(v, ind, s1, s2, num)
3+
*/
4+
#include <cstdio>
5+
#include <cstdlib>
6+
#include <algorithm>
7+
#include <cstring>
8+
#include <iostream>
9+
#include <cmath>
10+
#include <mex.h>
11+
12+
using namespace std;
13+
14+
int n, g;
15+
int flmt, glmt;
16+
double *v, *x, *num;
17+
int *vidx;
18+
int *gidx;
19+
20+
// 1-based
21+
double ***d;
22+
int ***path;
23+
double **selection_value;
24+
25+
bool cmp (int x, int y) {
26+
return fabs(v[x]) > fabs(v[y]) + 1e-9;
27+
}
28+
29+
void dp_preprocess () {
30+
int i, j, k;
31+
32+
d = (double ***)malloc((g + 1) * sizeof(double **));
33+
path = (int ***)malloc((g + 1) * sizeof(int **));
34+
for (i = 0; i <= g; ++i) {
35+
d[i] = (double **)malloc((glmt + 1) * sizeof(double *));
36+
path[i] = (int **)malloc((glmt + 1) * sizeof(int *));
37+
for (j = 0; j <= glmt; ++j) {
38+
d[i][j] = (double *)malloc((flmt + 1) * sizeof(double));
39+
path[i][j] = (int *)malloc((flmt + 1) * sizeof(int));
40+
for (k = 0; k <= flmt; ++k) {
41+
d[i][j][k] = .0;
42+
path[i][j][k] = 0;
43+
}
44+
}
45+
}
46+
int up = 0;
47+
for (i = 1; i <= g; ++i)
48+
up = max(up, gidx[i] - gidx[i-1]);
49+
50+
selection_value = (double **)malloc((g + 1) * sizeof(double *));
51+
for (i = 0; i <= g; ++i)
52+
selection_value[i] = (double *)malloc((up + 1) * sizeof(double));
53+
54+
vidx = (int *)malloc(n * sizeof(int));
55+
for (i = 0; i < n; ++i) vidx[i] = i;
56+
57+
for (i = 1; i <= g; ++i) {
58+
sort(vidx + gidx[i-1], vidx + gidx[i], cmp);
59+
selection_value[i][0] = .0;
60+
for (j = 1; j <= gidx[i] - gidx[i-1]; ++j)
61+
selection_value[i][j] = selection_value[i][j-1] +
62+
v[vidx[gidx[i-1]+j-1]] * v[vidx[gidx[i-1]+j-1]];
63+
}
64+
}
65+
66+
void dp () {
67+
int i, j, k, t;
68+
69+
dp_preprocess ();
70+
71+
for (i = 1; i <= g; ++i) {
72+
int upper = (int)num[i-1];
73+
for (j = 1; j <= glmt; ++j)
74+
for (k = 1; k <= flmt; ++k) {
75+
d[i][j][k] = d[i-1][j][k];
76+
77+
// min(|G_i|, min(num_i, k))
78+
int u = min(k, min(upper, gidx[i] - gidx[i-1]));
79+
80+
int max_idx = 0;
81+
double best = d[i][j][k], v;
82+
for (t = 1; t <= u; ++t) {
83+
if ((v = d[i-1][j-1][k-t] + selection_value[i][t]) > best) {
84+
best = v;
85+
max_idx = t;
86+
}
87+
}
88+
d[i][j][k] = best;
89+
path[i][j][k] = max_idx;
90+
}
91+
}
92+
}
93+
94+
void calc_sol () {
95+
int i, j, k;
96+
memset (x, 0, n * sizeof(double));
97+
98+
for (i = g, j = glmt, k = flmt; i >= 1; --i) {
99+
int num = path[i][j][k];
100+
k -= num;
101+
if (num) --j;
102+
for (int t = 1; t <= num; ++t)
103+
x[vidx[gidx[i-1]+t-1]] = v[vidx[gidx[i-1]+t-1]];
104+
}
105+
}
106+
107+
void destruct () {
108+
int i, j;
109+
110+
//free(v); v = NULL;
111+
free(vidx); vidx = NULL;
112+
free(gidx); gidx = NULL;
113+
for (i = 0; i <= g; ++i) {
114+
for (j = 0; j <= glmt; ++j) {
115+
free(d[i][j]); d[i][j] = NULL;
116+
free(path[i][j]); path[i][j] = NULL;
117+
}
118+
free(d[i]); d[i] = NULL;
119+
free(path[i]); path[i] = NULL;
120+
}
121+
free(d); d = NULL;
122+
free(path); path = NULL;
123+
124+
for (i = 0; i <= g; ++i) {
125+
free(selection_value[i]);
126+
selection_value[i] = NULL;
127+
}
128+
free(selection_value);
129+
selection_value = NULL;
130+
}
131+
132+
void mexFunction (int nlhs, mxArray* plhs[],
133+
int nrhs, const mxArray* prhs[])
134+
{
135+
v = mxGetPr(prhs[0]);
136+
double* gidx_double = mxGetPr(prhs[1]);
137+
138+
n = mxGetNumberOfElements(prhs[0]);
139+
g = mxGetNumberOfElements(prhs[1]) - 1;
140+
141+
gidx = (int *)malloc((g + 1) * sizeof(int));
142+
for (int i = 0; i <= g; ++i) gidx[i] = (int)gidx_double[i];
143+
144+
flmt = mxGetScalar(prhs[2]);
145+
glmt = mxGetScalar(prhs[3]);
146+
147+
num = mxGetPr(prhs[4]);
148+
149+
double eps = 1e-7;
150+
151+
plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
152+
x = mxGetPr(plhs[0]);
153+
154+
dp ();
155+
calc_sol ();
156+
//mexPrintf ("sglt: Optimal solution is %.5lf\n", d[g][glmt][flmt]);
157+
158+
destruct ();
159+
}
160+

0 commit comments

Comments
 (0)