Skip to content

Commit 42639b0

Browse files
committed
Improve fpl self starting
1 parent b0bbaf4 commit 42639b0

File tree

2 files changed

+33
-35
lines changed

2 files changed

+33
-35
lines changed

DESCRIPTION

-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,4 @@ Encoding: UTF-8
1818
Roxygen: list(markdown = TRUE)
1919
RoxygenNote: 7.3.1
2020
Imports:
21-
drc (>= 3.0.1),
2221
DEoptimR (>= 1.1.3)

R/drob.R

+33-34
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,26 @@ fpl <- list(
3333
a / (1 + a)
3434
)
3535
},
36-
init = function(x, y, extend = 3) {
37-
idx <- as.factor(x)
38-
s <- tapply(y, idx, sd)[idx]
39-
est <- drc::drm(y ~ x, fct = drc::LL.4(), weights = 1 / s)
40-
coef <- summary(est)$coefficients
41-
idx <- if (coef[1] >= 0) c(3, 1, 4, 2) else c(2, 1, 4, 3)
42-
t <- unname(coef[idx, 1])
43-
se <- unname(coef[idx, 2])
36+
init = function(x, y, extend = 15, eps = 1e-6) {
37+
ux <- unique(x)
38+
uy <- tapply(y, x, mean)
39+
r <- range(uy)
40+
p <- (uy - r[2]) / (r[1] - r[2])
41+
p <- (1 - 2 * eps) * p[ux != 0] + eps
42+
b <- coef(lm(log(p / (1 - p)) ~ log(ux[ux != 0])))
43+
t0 <- list(t1 = r[1], t2 = -b[2], t3 = exp(-b[1] / b[2]), t4 = r[2])
44+
w <- 1 / as.vector(tapply(y, x, sd)[as.factor(x)])
45+
b <- summary(nls(
46+
y ~ t4 + (t1 - t4) / (1 + (x / t3)^t2), start = t0, weights = w
47+
))$coefficients
48+
i <- if (b[2, 1] < 0) c(4, 2, 3, 1) else 1:4
49+
t <- unname(b[i, 1])
4450
t[2] <- abs(t[2])
45-
bound <- function(x, dir) x + dir * extend * abs(x)
46-
init <- list(t = t, se = se, lower = bound(t, -1), upper = bound(t, +1))
47-
init$lower[c(2, 3)] <- 1e-100
48-
init
51+
se <- unname(b[i, 2])
52+
lower <- t - extend * se
53+
upper <- t + extend * se
54+
lower[c(2, 3)] <- 0
55+
list(t = t, se = se, lower = lower, upper = upper)
4956
}
5057
)
5158

@@ -228,13 +235,11 @@ m_scale <- function(r, rho, extend = 5) {
228235
#' argument `"sl1"` to `step_2`). By default 1.48 in order to get 1 under
229236
#' the standard normal distribution.
230237
#'
231-
#' @param de_args A function that takes a list with the arguments to be passed
232-
#' to `JDEoptim` in step 1 and returns and arbitrarily modified list of
233-
#' arguments. By default it returns its argument unmodified.
238+
#' @param de_args A list that overrides arguments passed to `JDEoptim` in
239+
#' step 1 as in `utils::modifyList(args, de_args)`. By default it is empty.
234240
#'
235-
#' @param qn_args A function that takes a list with the arguments to be passed
236-
#' to `optim` in step 3 and returns and arbitrarily modified list of
237-
#' arguments. By default it returns its argument unmodified.
241+
#' @param qn_args A list that overrides arguments passed to `optim` in
242+
#' step 3 as in `utils::modifyList(args, qn_args)`. By default it is empty.
238243
#'
239244
#' @param qn_gr A flag that indicates if a gradient function is to be built and
240245
#' passed as the argument to the parameter `gr` of the `optim` routine in
@@ -247,11 +252,6 @@ m_scale <- function(r, rho, extend = 5) {
247252
#' be passed to `m_scale` in order to extend the root finding interval (for
248253
#' further details, refer to the documentation of `m_scale`). By default 5.
249254
#'
250-
#' @param init_extend Passed to `model$init` in order to extend the search
251-
#' region defined by the lower and bounds employed both for the differential
252-
#' evolution routine in step 1 and for the quasi-Newton L-BFGS-B routine in
253-
#' step 3. By default 3.
254-
#'
255255
#' @return A list containing the following elements:
256256
#' - `t`: a vector with the final location estimate, produced by step 3.
257257
#' - `t0`: a vector with the initial location estimate, produced by step 1.
@@ -293,11 +293,10 @@ drob <- function( # nolint
293293
mbi_k = 3.44,
294294
sbi_k = 1.548,
295295
sl1_k = 1.48,
296-
de_args = identity,
297-
qn_args = identity,
296+
de_args = list(),
297+
qn_args = list(),
298298
qn_gr = FALSE,
299-
ms_extend = 5,
300-
init_extend = 3
299+
ms_extend = 5
301300
) {
302301
select <- function(arg, ...) if (is.character(arg)) switch(arg, ...) else arg
303302
mbi <- bisquare(mbi_k)
@@ -307,7 +306,7 @@ drob <- function( # nolint
307306
"fpl" = fpl,
308307
stop("Invalid model '", model, "'")
309308
)
310-
init <- model$init(x, y, init_extend)
309+
init <- model$init(x, y)
311310
lower <- init$lower
312311
upper <- init$upper
313312

@@ -325,11 +324,11 @@ drob <- function( # nolint
325324
)
326325
grid <- matrix(runif(1000, lower, upper), nrow = length(lower))
327326
mloss <- median(abs(apply(grid, 2, loss)))
328-
args <- de_args(list(
329-
fn = function(t) loss(y - model$fun(x, t)),
330-
lower = lower, upper = upper, fnscale = mloss, tol = 1e-6
331-
))
332-
res <- do.call(DEoptimR::JDEoptim, args)
327+
args <- list(
328+
fn = function(t) loss(y - model$fun(x, t)), lower = lower, upper = upper,
329+
fnscale = mloss, tol = 1e-8, maxiter = 500 * length(lower)
330+
)
331+
res <- do.call(DEoptimR::JDEoptim, utils::modifyList(args, de_args))
333332
if (res$convergence == 1) stop("Step 1: optimizer failed")
334333
t0 <- res$par
335334

@@ -362,7 +361,7 @@ drob <- function( # nolint
362361
control = list(parscale = t0, trace = 0)
363362
)
364363
optimize <- function(...) {
365-
res <- do.call(optim, qn_args(append(args, list(...))))
364+
res <- do.call(optim, utils::modifyList(append(args, list(...)), qn_args))
366365
if (res$convergence %in% c(51, 52)) {
367366
stop("Step 3: optimizer failed with '", res$message, "'")
368367
}

0 commit comments

Comments
 (0)