-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathkruskal_sparsereg.m
350 lines (331 loc) · 12.2 KB
/
kruskal_sparsereg.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
function [beta0_final,beta_final,beta_scale,glmstats] = ...
    kruskal_sparsereg(X,M,y,r,dist,lambda,pentype,penparam,varargin)
% KRUSKAL_SPARSEREG Fit the rank-r GLM sparse Kruskal tensor regression
%   [BETA0,BETA,BETA_SCALE,GLMSTATS] =
%   KRUSKAL_SPARSEREG(X,M,Y,R,DIST,LAMBDA,PENTYPE,PENPARAM) fits the sparse
%   Kruskal tensor regression using the regular covariate matrix X,
%   multidimensional array (or tensor) variates M, response Y, rank of the
%   Kruskal tensor regression R, the assumed distribution of the model
%   DIST, the penalty PENTYPE at fixed tuning parameter value LAMBDA, and
%   the index of the penalty type PENPARAM. Available values of DIST are
%   'normal', 'binomial' and 'poisson'. For the input value PENTYPE,
%   available penalties are 'enet', 'log', 'mcp', 'power' and 'scad'. The
%   result BETA0 is the regression coefficient vector for the regular
%   covariates matrix X, BETA is the tensor regression coefficient of the
%   tensor covariates M, BETA_SCALE is the tensor of the scaling constants
%   for the tensor covariates M, and GLMSTATS is a 1-by-(d+1) cell array of
%   GLM summary statistics (see OUTPUT below).
%
%   [BETA0,BETA,BETA_SCALE,GLMSTATS] =
%   KRUSKAL_SPARSEREG(X,M,Y,R,DIST,LAMBDA,PENTYPE,PENPARAM,'PARAM1',val1,
%   'PARAM2',val2...) allows you to specify optional parameters to control
%   the model fit. Available parameter name/value pairs are:
%       'B0': starting point, it can be a numeric array or a tensor
%       'BurninMaxIter': Max. iter. for the burn-in runs, default is 20
%       'BurninTolFun': Tolerance for the burn-in runs, default is 1e-2
%       'BurninReplicates': Number of the burn-in runs, default is 5
%       'Display': 'off' (default) or 'iter'
%       'PenaltyMaxIter': Max. iters. at penalization stage, default is 50
%       'PenaltyTolFun': Tolerance at penalization stage, default is 1e-3
%       'warn': logical; false (default) silences the glmfit warnings
%           IterationLimit, BadScaling and IllConditioned during the
%           penalization and scaling stages
%       'weights': observation weights, default is ones for each obs.
%
%   INPUT:
%       X: n-by-p0 regular covariate matrix
%       M: array variates (or tensors) with dim(M) = [p1,p2,...,pd,n]
%       y: n-by-1 response vector
%       r: rank of tensor regression (an empty r is treated as 1)
%       dist: 'binomial'|'normal'|'poisson'
%       lambda: penalty tuning constant, must be strictly positive
%       pentype: 'enet'|'log'|'mcp'|'power'|'scad'
%       penparam: the index parameter for the pentype
%
%   OUTPUT:
%       beta0_final: regression coefficients for the regular covariates
%       beta_final: a tensor of regression coefficients for array variates
%       beta_scale: a tensor of the scaling constants for the array
%           coefficients
%       glmstats: 1-by-(d+1) cell array; glmstats{d+1} holds the GLM
%           statistics from the last fitting of the regular covariates
%           plus a BIC field, glmstats{d} holds the statistics of the last
%           scaling-stage fit
%
%   ERRORS:
%       tensorreg:kruskal_sparsereg:nopen    lambda is zero
%       tensorreg:kruskal_sparsereg:baddist  dist is not supported
%       tensorreg:kruskal_sparsereg:dim      last dim of M differs from n
%       tensorreg:kruskal_sparsereg:badB0    B0 has the wrong number of dims
%
% Examples
%
% See also kruskal_reg, matrix_sparsereg, tucker_reg, tucker_sparsereg.
%
% Reference
%   H Zhou, L Li, and H Zhu (2013) Tensor regression with applications in
%   neuroimaging data analysis, JASA 108(502):540-552
%
% TODO
%
% COPYRIGHT 2011-2013 North Carolina State University
% Hua Zhou <[email protected]>

% parse inputs
argin = inputParser;
argin.addRequired('X', @isnumeric);
argin.addRequired('M', @(x) isa(x,'tensor') || isnumeric(x));
argin.addRequired('y', @isnumeric);
argin.addRequired('r', @isnumeric);
argin.addRequired('dist', @(x) ischar(x));
argin.addRequired('lambda', @(x) isnumeric(x) && x>=0);
argin.addRequired('pentype', @ischar);
argin.addRequired('penparam', @isnumeric);
argin.addParamValue('B0', [], @(x) isnumeric(x) || ...
    isa(x,'tensor') || isa(x,'ktensor') || isa(x,'ttensor'));
argin.addParamValue('Display', 'off', @(x) strcmp(x,'off')||strcmp(x,'iter'));
argin.addParamValue('BurninMaxIter', 20, @(x) isnumeric(x) && x>0);
argin.addParamValue('BurninTolFun', 1e-2, @(x) isnumeric(x) && x>0);
argin.addParamValue('BurninReplicates', 5, @(x) isnumeric(x) && x>0);
argin.addParamValue('PenaltyMaxIter', 50, @(x) isnumeric(x) && x>0);
argin.addParamValue('PenaltyTolFun', 1e-3, @(x) isnumeric(x) && x>0);
argin.addParamValue('warn', false, @(x) islogical(x));
argin.addParamValue('weights', [], @(x) isnumeric(x) && all(x>=0));
argin.parse(X,M,y,r,dist,lambda,pentype,penparam,varargin{:});

B0 = argin.Results.B0;
Display = argin.Results.Display;
BurninMaxIter = argin.Results.BurninMaxIter;
BurninTolFun = argin.Results.BurninTolFun;
BurninReplicates = argin.Results.BurninReplicates;
PenaltyMaxIter = argin.Results.PenaltyMaxIter;
PenaltyTolFun = argin.Results.PenaltyTolFun;
warn = argin.Results.warn;
wts = argin.Results.weights;
if isempty(wts)
    % default: unit weight for every observation (row of X)
    wts = ones(size(X,1),1);
end

% check positivity of tuning parameter
if lambda==0
    error('tensorreg:kruskal_sparsereg:nopen', ...
        'lambda=0 (no penalization); call kruskal_reg instead');
end

% check validity of rank r
if isempty(r)
    r = 1;
end

% decide least squares or GLM model; for GLMs also translate DIST to the
% model specifier expected by glm_sparsereg. Unsupported distributions are
% rejected here, otherwise glmmodel/shrink_factor would be left undefined
% and the failure would surface later as an undefined-variable error.
if strcmpi(dist,'normal')
    isglm = false;
elseif strcmpi(dist,'binomial')
    isglm = true;
    glmmodel = 'logistic';
elseif strcmpi(dist,'poisson')
    isglm = true;
    glmmodel = 'loglinear';
else
    error('tensorreg:kruskal_sparsereg:baddist', ...
        'dist must be ''normal'', ''binomial'' or ''poisson''');
end

% check dimensions
[n,p0] = size(X);
d = ndims(M)-1; % dimension of array variates
p = size(M);    % sizes array variates; the last entry is the sample size
if p(end)~=n
    error('tensorreg:kruskal_sparsereg:dim', ...
        'dimension of M does not match that of X!');
end

% convert M into a tensor T
TM = tensor(M);

% if space allowing, pre-compute mode-d matricization of TM
iswindows = ispc;
if iswindows
    % memory function is only available on windows !!!
    [~,sys] = memory;
end
% CAUTION: may cause out of memory on Linux
% (without the memory query the matricizations are always pre-computed)
if ~iswindows || d*(8*prod(size(TM)))<.75*sys.PhysicalMemory.Available %#ok<PSIZE>
    Md = cell(d,1);
    for dd=1:d
        % rows run over (observation, mode dd), columns over the other modes
        Md{dd} = double(tenmat(TM,[d+1,dd],[1:dd-1 dd+1:d]));
    end
end

% Burn-in stage (loose convergence criterion)
if ~strcmpi(Display,'off')
    display(' ');
    display('==================');
    display('Burn-in stage ...');
    display('==================');
end

% reduce tensor size for reliable estimation in burnin stage
if isempty(B0)  % no user-supplied start point
    % shrink_factor compares a fraction of the sample size with the number
    % of entries r*sum(p(1:end-1)) in the Kruskal factor matrices
    if strcmpi(dist,'normal')
        shrink_factor = (n/5) / (r*sum(p(1:end-1)));
    elseif strcmpi(dist,'binomial')
        shrink_factor = (n/10) / (r*sum(p(1:end-1)));
    elseif strcmpi(dist,'poisson')
        shrink_factor = (n/10) / (r*sum(p(1:end-1)));
    end
    if shrink_factor <=1
        [~,beta_burnin] = kruskal_reg(X,M,y,r,dist, ...
            'MaxIter',BurninMaxIter, ...
            'TolFun',BurninTolFun,...
            'Replicates',BurninReplicates,'weights',wts);
    else
        targetdim = round(p(1:end-1)/shrink_factor);
        M_reduce = array_resize(M, targetdim);
        % estimate at reduced dimension
        [~,beta_burnin] = kruskal_reg(X,M_reduce,y,r,dist, ...
            'MaxIter',BurninMaxIter, ...
            'TolFun',BurninTolFun,...
            'Replicates',BurninReplicates,'weights',wts);
        % resize back to original dimension
        beta_burnin = array_resize(beta_burnin, p(1:end-1));
        % warm start from coarsened estimate
        [~,beta_burnin] = kruskal_reg(X,M,y,r,dist, ...
            'B0', beta_burnin, ...
            'MaxIter',BurninMaxIter, ...
            'TolFun',BurninTolFun,...
            'weights',wts);
    end
else    % user-supplied start point
    % check dimension
    if ndims(B0)~=d
        error('tensorreg:kruskal_sparsereg:badB0', ...
            'dimension of B0 does not match that of data!');
    end
    % turn B0 into a tensor (if it's not)
    if isnumeric(B0)
        B0 = tensor(B0);
    end
    % resize to compatible dimension (if it's not); the target excludes
    % the trailing sample-size entry of p
    if any(size(B0)~=p(1:end-1))
        B0 = array_resize(B0, p(1:end-1));
    end
    % perform CP decomposition if it's not a ktensor of correct rank
    if isa(B0,'tensor') || isa(B0,'ttensor') || ...
            (isa(B0, 'ktensor') && size(B0.U{1},2)~=r)
        B0 = cp_als(B0, r, 'printitn', 0);
    end
    % make sure B0.U is a 1-by-d cell array
    B0.U = reshape(B0.U, 1, d);
    beta_burnin = B0;
end

% penalization stage
if ~strcmpi(Display,'off')
    display(' ');
    display('==================');
    display('Penalization stage');
    display('==================');
end

% turn off warnings from glmfit_priv; the caller's previous settings are
% captured and put back when this function exits, whether it returns
% normally or through an error
if ~warn
    warnstate = warning('off','stats:glmfit:IterationLimit');
    warnstate(2) = warning('off','stats:glmfit:BadScaling');
    warnstate(3) = warning('off','stats:glmfit:IllConditioned');
    restorewarn = onCleanup(@() warning(warnstate)); %#ok<NASGU>
end

glmstats = cell(1,d+1);
dev0 = inf;
beta = beta_burnin;
for iter = 1:PenaltyMaxIter
    % update regular covariate coefficients; the array systematic part
    % eta enters as one extra column whose coefficient rescales beta below
    if (iter==1)
        eta = double(tenmat(TM,d+1)*tenmat(beta,1:d));
    else
        % reuse the design matrix and solution of the last mode update
        eta = Xj*betatmp(1:end-1);
    end
    [betatmp,devtmp,glmstats{d+1}] = glmfit_priv([X,eta],y,dist, ...
        'constant','off', ...
        'weights',wts);
    beta0 = betatmp(1:p0);

    % stopping rule: relative change of the deviance
    % (dev0 starts at inf, so the first iteration never stops here)
    diffdev = devtmp-dev0;
    dev0 = devtmp;
    if (abs(diffdev)<PenaltyTolFun*(abs(dev0)+1))
        break;
    end

    % update scale of array coefficients and standardize: spread lambda
    % evenly over the d factor matrices, then reset it to ones
    beta = arrange(beta*betatmp(end));
    for j=1:d
        beta.U{j} = bsxfun(@times,beta.U{j},(beta.lambda').^(1/d));
    end
    beta.lambda = ones(r,1);

    % cyclic update of array regression coefficients
    eta0 = X*beta0;
    for j=1:d
        if (j==1)
            cumkr = ones(1,r);
        end
        % build the n-by-(p(j)*r) design matrix for the j-th factor matrix
        if (exist('Md','var'))
            if (j==d)
                Xj = reshape(Md{j}*cumkr,n,p(j)*r);
            else
                Xj = reshape(Md{j}*khatrirao([beta.U(d:-1:j+1),cumkr]),...
                    n,p(j)*r);
            end
        else
            if (j==d)
                Xj = reshape(double(tenmat(TM,[d+1,j]))*cumkr, ...
                    n,p(j)*r);
            else
                Xj = reshape(double(tenmat(TM,[d+1,j])) ...
                    *khatrirao({beta.U{d:-1:j+1},cumkr}),n,p(j)*r);
            end
        end
        % penalized fit; penidx leaves the last column (eta0) unpenalized
        if (isglm)
            betatmp = glm_sparsereg([Xj,eta0],y,lambda,glmmodel, ...
                'weights',wts,...
                'penidx',[true(1,p(j)*r),false],...
                'penalty',pentype,'penparam',penparam);
        else
            betatmp = lsq_sparsereg([Xj,eta0],y,lambda,'weights',wts,...
                'x0',[beta{j}(:);0],'penidx',[true(1,p(j)*r),false],...
                'penalty',pentype,'penparam',penparam);
        end
        beta{j} = reshape(betatmp(1:end-1),p(j),r);
        eta0 = eta0*betatmp(end);
        % accumulate the Khatri-Rao product of the modes updated so far
        cumkr = khatrirao(beta{j},cumkr);
    end

    if (~strcmpi(Display,'off'))
        disp(' ');
        disp(['  iterate: ' num2str(iter)]);
        disp([' deviance: ' num2str(dev0)]);
        disp(['    beta0: ' num2str(beta0')]);
    end
end
beta0_final = beta0;
beta_final = beta;

if (~strcmpi(Display,'off'))
    display(' ');
    display('==================');
    display('Scaling stage');
    display('==================');
end

% find a scaling for the estimates: the standard errors of an unpenalized
% fit for each mode, with the other modes held fixed
beta_scale = ktensor(arrayfun(@(j) zeros(p(j),r), 1:d, ...
    'UniformOutput',false));
eta0 = X*beta0;
for j=1:d
    idxj = 1:d;
    idxj(j) = [];
    if (exist('Md','var'))
        Xj = reshape(Md{j}*khatrirao(beta.U(idxj(end:-1:1))),n,p(j)*r);
    else
        Xj = reshape(double(tenmat(TM,[d+1,j])) ...
            *khatrirao(beta.U(idxj(end:-1:1))),n,p(j)*r);
    end
    % NOTE(review): this fit does not pass 'weights',wts and stores every
    % mode's statistics in glmstats{d} (overwritten on each pass) -- confirm
    % both are intended
    [~,~,glmstats{d}] = glmfit_priv([Xj,eta0],y,dist,'constant','off');
    beta_scale{j} = reshape(glmstats{d}.se(1:end-1),p(j),r);
end

% output the BIC
% NOTE(review): only the d==2 branch applies cutoff when counting nonzero
% coefficients; the general branch counts exact nonzeros -- confirm
cutoff = 1e-8;
if (d==2)
    glmstats{d+1}.BIC = dev0 + log(n)*max(nnz(abs(beta.U{1})>cutoff) ...
        + nnz(abs(beta.U{2})>cutoff)-r*r,0);
else
    glmstats{d+1}.BIC = dev0 + log(n)* ...
        max(sum(arrayfun(@(j) nnz(beta.U{j}), 1:d, 'UniformOutput',true)) ...
        - r*(d-1),0);
end

% say goodbye
if (~strcmpi(Display,'off'))
    disp(' ');
    disp(' DONE!');
    disp(' ');
end

% the glmfit warning settings are restored by the onCleanup object created
% before the penalization stage once the function returns

end