-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflux_conserve_proj.py
126 lines (112 loc) · 3.94 KB
/
flux_conserve_proj.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Approach used for the flux-conservation constraint from http://dx.doi.org/10.1051/0004-6361/201118681 and Bonettini et al. 2009.
import numpy as np
def projectDF(b, c, dia, lambda_=0, dlambda_=1, tol_lam=1e-11, max_projs=1000):
"""
Equation: min 0.5 * x' * diag(dia) * x - c' * x
subj to sum(x) = b
x >= 0
"""
c = c.astype(np.float64, copy=False)
dia = dia.astype(np.float64, copy=False)
b = b.astype(np.float64, copy=False)
biter, siter = 0, 0
tol_r = 1e-11 * b
# Bracketing phase.
# np.maximum takes element-wise maximum unlike np.max
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
if abs(r) < tol_r:
return x
if r < 0:
lambdal = lambda_
rl = r
lambda_ = lambda_ + dlambda_
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
while r < 0:
biter = biter + 1
lambdal = lambda_
## The below commented two lines can also be used as a replacement for the following two lines: from the paper.
# s = min(1, max(rl/r - 1, 0.1))
# dlambda_ /= s
s = np.max([rl/r-1, 0.1])
dlambda_ = dlambda_ + dlambda_/s
lambda_ = lambda_ + dlambda_
rl = r
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
lambdau = lambda_
ru = r
else:
lambdau = lambda_
ru = r
lambda_ = lambda_ - dlambda_
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
while r > 0:
biter = biter + 1
lambdau = lambda_
s = np.max([ru/r-1, 0.1])
# Program hangs for some image at the below line. Hence catch warning as an exception and continue.
try:
with np.errstate(all='raise'):
dlambda_ = dlambda_ + dlambda_/s
except Exception:
break
# raise ValueError("Something bad happened!")
lambda_ = lambda_ - dlambda_
ru = r
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
lambdal = lambda_
rl = r
# Check ru and rl
if abs(ru) < tol_r:
x = np.maximum(0, np.divide(c+lambdau, dia))
return x
if abs(rl) < tol_r:
x = np.maximum(0, np.divide(c+lambdal, dia))
return x
# Secant phase
s = 1 - rl/ru
dlambda_ = dlambda_/s
lambda_ = lambdau - dlambda_
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
maxit_s = max_projs - biter
# Main loop
while abs(r) > tol_r and dlambda_ > tol_lam * (1 + abs(lambda_)) and siter < maxit_s:
siter = siter + 1
if r > 0:
if s <= 2:
lambdau = lambda_
ru = r
s = 1-rl/ru
dlambda_ = (lambdau-lambdal)/s
lambda_ = lambdau - dlambda_
else:
s = np.max([ru/r-1, 0.1])
dlambda_ = (lambdau-lambda_) / s
lambda_new = np.max([lambda_-dlambda_, 0.75*lambdal+0.25*lambda_])
lambdau = lambda_
ru = r
lambda_ = lambda_new
x = (lambdau-lambdal) / (lambdau-lambda_)
else:
if s >= 2:
lambdal = lambda_
rl = r
s = 1-rl/ru
dlambda_ = (lambdau-lambdal)/s
lambda_ = lambdau - dlambda_
else:
s = np.max([rl/r-1, 0.1])
dlambda_ = (lambda_-lambdal) / s
lambda_new = np.min([lambda_+dlambda_, 0.75*lambdau+0.25*lambda_])
lambdal = lambda_
rl = r
lambda_ = lambda_new
s = (lambdau-lambdal) / (lambdau-lambda_)
x = np.maximum(0, np.divide(c+lambda_, dia))
r = np.sum(x) - b
return x