Skip to content

Commit 370efb9

Browse files
committed
Change default parameters
1. For L-BFGS, the curvature can be very small, so the curvature tolerance is lowered. 2. When iterating between A and Z, their line-search step sizes can be difficult to identify. Now using a large maxiter to ensure descent and an adaptive init_stepsize to make it faster.
1 parent 06ddd62 commit 370efb9

File tree

3 files changed

+25
-24
lines changed

3 files changed

+25
-24
lines changed

RarmaSolvers.m

+10-10
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
% regularized ARMA
44

55
properties(Constant)
6-
SUCCESS = 0;
7-
ERROR_MAXITER = 10;
8-
ERROR_BACKTRACK = 20;
9-
ERROR_BACKTRACK_MAXITER = 21
6+
SUCCESS = 0;
7+
ERROR_MAXITER = 10;
8+
ERROR_BACKTRACK = 20;
9+
ERROR_BACKTRACK_MAXITER = 21;
1010
end
1111

1212
methods(Static)
@@ -16,7 +16,7 @@
1616
% x0 can be a matrix
1717

1818
backtrack_backoff = 0.5;
19-
backtrack_maxiter = 10;
19+
backtrack_maxiter = 100; % make this large to ensure descent
2020

2121
for iter = 1:backtrack_maxiter
2222

@@ -38,8 +38,8 @@
3838

3939
if ~isa(fun,'function_handle'), error('fmin_LBFGS -> improper function handle'); end
4040

41-
DEFAULTS.curvTol = 1e-3;
42-
DEFAULTS.funTol = 1e-6;
41+
DEFAULTS.curvTol = 1e-6; % for curvature condition
42+
DEFAULTS.funTol = 1e-6;
4343
DEFAULTS.m = 50; % number gradients in bundle
4444
DEFAULTS.maxiter = 1000;
4545
DEFAULTS.verbose = 0;
@@ -58,7 +58,7 @@
5858
opts = RarmaUtilities.getOptions(opts, DEFAULTS);
5959
end
6060

61-
x = x0;
61+
x = x0;
6262
flag = RarmaSolvers.SUCCESS;
6363

6464
t = length(x0);
@@ -70,7 +70,7 @@
7070
slope = Inf;
7171

7272
% damped limited memory BFGS method
73-
[f,g] = fun(x);
73+
[f,g] = fun(x);
7474
for iter = 1:opts.maxiter
7575
% compute search direction
7676
dir = RarmaSolvers.invhessmult(-g,Y,S,Rho,H0,inds,opts.m);
@@ -85,7 +85,7 @@
8585
s = xnew - x;
8686
y = gnew - g;
8787
curvature = y'*s;
88-
if curvature > opts.curvTol
88+
if curvature > opts.curvTol
8989
rho = 1/curvature;
9090
if length(inds) < opts.m
9191
i = length(inds)+1;

demo.m

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
clear
22
clc
33

4-
rng(100);
4+
rng(10);
55
num_reps = 10;
66
Models = cell(num_reps,1);
77
Xpredictall = cell(num_reps,1);
8+
isStable = zeros(num_reps,1);
89
Err = zeros(num_reps,1);
910

1011
%% Generate data
@@ -14,15 +15,15 @@
1415

1516
%% Learn RARMA
1617
% In practice, the following parameters should be
17-
% cross-validated before applied
18+
% cross-validated for EACH dataset before applied
1819
opts = [];
1920
opts.ardim = 2;
2021
opts.madim = 2;
21-
opts.reg_wgt_ar = 1e-2;
22-
opts.reg_wgt_ma = 1e-1;
22+
opts.reg_wgt_ar = 0.07; % stronger regularization on A -> more stable
23+
opts.reg_wgt_ma = 0.01;
2324
for ii = 1:num_reps
2425
Models{ii} = rarma(Xtrainall{ii},opts);
25-
[isStable, eigs] = RarmaUtilities.checkStable(Models{ii}.A)
26+
isStable(ii) = RarmaUtilities.checkStable(Models{ii}.A);
2627
end
2728

2829
%% Prediction and Evaluation

rarma.m

+9-9
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
DEFAULTS.ardim = 5;
2828
DEFAULTS.init_stepsize = 10;
2929
DEFAULTS.Loss = @RarmaFcns.euclidean_rarma;
30-
DEFAULTS.maxiter = 100;
30+
DEFAULTS.maxiter = 1000;
3131
DEFAULTS.madim = 5;
3232
DEFAULTS.recover = 1; % Recover B and Epsilon from learned Z
3333
DEFAULTS.reg_ar = @RarmaFcns.frob_norm_sq;
@@ -107,14 +107,14 @@
107107
Z = zeros(sizeZ);
108108
[A, obj, iter, msg] =opts.optimizer(@(Avec)(objA(Avec, X, Z)), Ainit(:));
109109
A = reshape(A, sizeA);
110-
[Z, prev_obj] = iterateZ(Z,A);
110+
[Z, prev_obj] = iterateZ(Z,A,opts.init_stepsize);
111111

112112
for i = 1:opts.maxiter
113113
% Do A first since it returns the incorrect obj
114-
A = iterateA(A,Z);
115-
[Z, obj] = iterateZ(Z,A);
114+
A = iterateA(A,Z,opts.init_stepsize/i); % adaptive stepsize
115+
[Z, obj] = iterateZ(Z,A,opts.init_stepsize/i);
116116

117-
if abs(obj-prev_obj) < opts.TOL
117+
if abs(prev_obj-obj) < opts.TOL % doing minimization
118118
break;
119119
end
120120
prev_obj = obj;
@@ -133,15 +133,15 @@
133133
model.predict = @(Xstart, horizon, opts)(RarmaFcns.iterate_predict_ar(Xstart, model, horizon, opts));
134134
end
135135

136-
function [A, f] = iterateA(A, Z)
136+
function [A, f] = iterateA(A, Z, init_stepsize)
137137
[f,g] = objA(A, X, Z);
138-
stepsize = RarmaSolvers.line_search(A, f, g, @(A)(objA(A, X, Z)),opts.init_stepsize);
138+
stepsize = RarmaSolvers.line_search(A, f, g, @(A)(objA(A, X, Z)),init_stepsize);
139139
A = A - stepsize*(g);
140140
end
141141

142-
function [Z, f] = iterateZ(Z, A)
142+
function [Z, f] = iterateZ(Z, A, init_stepsize)
143143
[f,g] = objZ(Z, A);
144-
stepsize = RarmaSolvers.line_search(Z, f, g, @(Z)(objZ(Z, A)), opts.init_stepsize/10);
144+
stepsize = RarmaSolvers.line_search(Z, f, g, @(Z)(objZ(Z, A)), init_stepsize);
145145
Z = Z - stepsize*g;
146146
end
147147

0 commit comments

Comments
 (0)