-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #198 from birdflow-science/plot-loss
- Loading branch information
Showing
21 changed files
with
932 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: BirdFlowR | ||
Title: Predict and Visualize Bird Movement | ||
Version: 0.1.0.9065 | ||
Version: 0.1.0.9066 | ||
Authors@R: | ||
c(person("Ethan", "Plunkett", email = "[email protected]", role = c("aut", "cre"), | ||
comment = c(ORCID = "0000-0003-4405-2251")), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
|
||
# Brownian bridge based weighting. This is a placeholder for now to allow | ||
# implementing weight_betweeness. | ||
|
||
#' Calculate the weights of transitions for flux points
#'
#' `calc_dist_weights()` is an internal function that takes summary stats
#' on the relationship between points and a transition line and returns
#' the weight that should be used for that transition.
#'
#' The first three arguments can all be vectors in which case the calculations
#' will be vectorized over the corresponding elements.
#'
#' The weight for each element is the normal probability mass between
#' `dist_to_line - radius_m` and `dist_to_line + radius_m` (after conversion
#' to km), using a standard deviation that depends on `method`. Elements for
#' which the near edge of that band is 1.96 or more standard deviations from
#' the line receive a weight of zero.
#'
#' This is a preliminary version of the function and will likely change.
#'
#' @param dist_to_line How far is the point from the line (m)
#' @param dist_along_line How far along the line is the point, after
#' projecting it onto the line (m)
#' @param line_lengths How long is the line (m)
#' @param radius_m The radius of the transect at the flux points - used to
#' determine the band of probability density that will be added to form
#' the weight.
#' @param res_m The resolution of the associated bird flow model, used to
#' determine the nugget added to the variance to represent the uncertainty in
#' the starting and ending location of the transition.
#'
#' @param method The method used for calculating the standard deviation
#' in the probability distribution. Must be a single string. Currently
#' `"m3"`, Matern 3/2; and `"bb"`, Brownian bridge are supported.
#'
#' @return A vector of weights of the same length as the first three arguments.
#' @keywords internal
calc_dist_weights <- function(dist_to_line, dist_along_line, line_lengths,
                              radius_m, res_m, method = "m3") {

  valid_methods <- c("bb",  # Brownian bridge
                     "m3")  # Matern 3/2

  # Validate before doing any work. Requiring a single value keeps the
  # condition scalar, so a vector `method` gets this message rather than a
  # cryptic "condition has length > 1" error.
  if (length(method) != 1 || !method %in% valid_methods) {
    stop("Method should be one of ", paste(valid_methods, collapse = ", "))
  }

  # Spatial information, converted from m to km
  len <- line_lengths / 1000   # Length of the line
  t <- dist_along_line / 1000  # Distance from the start of the line to where
                               # the point projects onto the line
  d <- dist_to_line / 1000     # Distance from the point to the line
  r <- radius_m / 1000         # Radius of the transect
  res_km <- res_m / 1000

  # s2 is the standard deviation of the nugget
  s2 <- res_km / 4  # This initial setting of 1/4 the cell width converted
                    # to km means the cell boundary (orthogonally) is at 2
                    # standard deviations from the cell center, the corners
                    # will be at 2.8 SD.

  if (method == "bb") {

    # Hyperparameter
    s1 <- 10  # Tune this visually looking at a few resulting distributions?

    # Brownian bridge variance component.
    bridge_var <- s1^2 * t * (len - t) / len

    # A zero-length line makes the expression above 0 / 0 (NaN). There is
    # no room for a bridge in that case so only the nugget contributes.
    zero_length <- !is.na(len) & len == 0
    bridge_var[zero_length] <- 0

    sd <- sqrt(bridge_var + s2^2)  # SD in km for each point

  } else {

    # method == "m3"
    sd <- sqrt(calc_martern_variance(t, len, k_m3, gamma = 40000, kl = 2000) +
                 s2^2)
  }

  # Calculate weight: probability mass within +/- r of the point's distance
  # from the line. Points whose band starts at or beyond 1.96 SD keep a
  # weight of 0.
  weight <- rep(0, length(sd))
  in_range <- (d - r) < 1.96 * sd
  low_cum_prob <- stats::pnorm(d[in_range] - r, sd = sd[in_range])
  high_cum_prob <- stats::pnorm(d[in_range] + r, sd = sd[in_range])
  weight[in_range] <- high_cum_prob - low_cum_prob

  weight
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Define kernel functions | ||
# Original version used "l" instead of "kl" but that led to a | ||
# warning about partial matching of "l" to "len" when invoked via ... | ||
# in calc_dist_weights(). kl makes it unambiguous. | ||
# Kernel: gamma scaled by (1 + sqrt(3) * d / kl).
# d is a distance, gamma a scale and kl a length parameter; all are required.
# NOTE(review): unlike k_m3() and k_m5() this has no exponential decay term,
# so it grows with d - confirm that is intended.
k_m1 <- function(d, gamma, kl) {
  scaled_d <- sqrt(3) * d / kl
  gamma * (1 + scaled_d)
}
|
||
# Kernel with the Matern 3/2 form:
#   gamma * (1 + sqrt(3) * d / kl) * exp(-sqrt(3) * d / kl)
# d is a distance; gamma and kl both default to 40.
k_m3 <- function(d, gamma = 40, kl = 40) {
  scaled_d <- sqrt(3) * d / kl
  gamma * (1 + scaled_d) * exp(-scaled_d)
}
|
||
# Kernel with the Matern 5/2 form:
#   gamma * (1 + sqrt(5) * d / kl + 5 * d^2 / (3 * kl^2)) *
#     exp(-sqrt(5) * d / kl)
# d is a distance; gamma and kl both default to 40.
k_m5 <- function(d, gamma = 40, kl = 40) {
  scaled_d <- sqrt(5) * d / kl
  polynomial <- 1 + scaled_d + 5 * d^2 / (3 * kl^2)
  gamma * polynomial * exp(-scaled_d)
}
|
||
# Squared exponential kernel: gamma * exp(-0.5 * (d / kl)^2).
# d is a distance, gamma a scale and kl a length parameter; all are required.
k_sq <- function(d, gamma, kl) {
  ratio <- d / kl
  gamma * exp(-0.5 * ratio^2)
}
|
||
# Calculate a variance from a kernel function evaluated at distances 0, t,
# len - t, and len. Negative results are clamped to zero.
#
# t:   distance(s) along the line
# len: length(s) of the line
# k:   kernel function to use, presumably one of those defined above; its
#      first argument is a distance.
# ...: additional parameters passed through to k (e.g. gamma and kl).
calc_martern_variance <- function(t, len, k, ...) {
  k_zero <- k(0, ...)         # kernel at zero distance
  k_start <- k(t, ...)        # kernel at the distance from the line's start
  k_end <- k(len - t, ...)    # kernel at the distance to the line's end
  k_full <- k(len, ...)       # kernel across the whole line

  numerator <- k_start^2 + k_end^2 - 2 *
    k_start * k_end * k_full / k_zero
  denominator <- k_zero * (1 - k_full^2 / k_zero^2)
  variance <- k_zero - numerator / denominator

  pmax(variance, 0)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
|
||
#' Get loss values for each step in the model fitting process
#'
#' `get_loss()` returns the data frame of loss values stored in the metadata
#' of a BirdFlow object. Each row corresponds to a step in the fitting
#' process. An error is thrown if the metadata has no `loss_values` item.
#' @param bf A BirdFlow object.
#'
#' @return A data frame with columns:
#' \item{dist}{The distance loss}
#' \item{ent}{The entropy loss}
#' \item{obs}{The observation loss}
#' \item{total}{The total weighted loss}
#' @export
#'
#' @examples
#' bf <- BirdFlowModels::amewoo
#' get_loss(bf)
#' @seealso [plot_loss()]
get_loss <- function(bf) {
  metadata_names <- names(bf$metadata)
  has_loss <- is.element("loss_values", metadata_names)

  if (!has_loss) {
    stop("Missing loss values in the BirdFlow metadata. ",
         "Is this a fitted model?")
  }

  bf$metadata$loss_values
}
|
||
|
||
|
||
|
||
#' Plot changes in component and total loss during model fitting
#'
#' Model fitting - in [BirdFlowPy]() - attempts to minimize the total weighted
#' loss. This plot shows four lines:
#' * **Total loss** is the weighted sum of the three loss components. The
#' weighting may cause it to be lower than some of the components.
#' * **Observation loss** captures how well the model predicts the Status and
#' Trend distributions it was trained on. Its weight is always set to 1 and its
#' relative weight is changed by adjusting the other two weights which are
#' usually much less than 1.
#' * **Distance loss** is lower when the routes encoded in the model are
#' shorter.
#' * **Entropy loss** is lower when the entropy in the model is higher.
#'
#' The weights used for the three components are read from the model's
#' hyperparameters and shown in the plot subtitle.
#'
#' @param bf A fitted Bird Flow model
#' @param transform Passed to [ggplot2::scale_y_continuous()] to set the y-axis
#' transformation. Reasonable values for this function include
#' "identity", "log", "log10", "log2", and "sqrt".
#' @return a **ggplot2** plot object.
#' @export
#' @examples
#' bf <- BirdFlowModels::amewoo
#' plot_loss(bf)
plot_loss <- function(bf, transform = "log10") {

  # Both the loss history and the hyperparameters (for the subtitle) are
  # required.
  if (!all(c("loss_values", "hyperparameters") %in% names(bf$metadata))) {
    stop("Missing loss or hyperparameters from metadata. ",
         "Is this a fitted model?")
  }

  # Use display names for the loss components as they appear in the legend.
  loss <- get_loss(bf) |>
    dplyr::rename(Distance = "dist",
                  Entropy = "ent",
                  Observation = "obs",
                  Total = "total")
  loss$Step <- seq_len(nrow(loss))

  # Convert to long format. Step was just added as the last column so
  # -ncol(loss) pivots every column except Step.
  d <- tidyr::pivot_longer(loss, cols = -ncol(loss), values_to = "Loss",
                           names_to = "Type")

  hp <- bf$metadata$hyperparameters

  subtitle <- paste0("Weights: Observation:", hp$obs_weight,
                     " Distance: ", hp$dist_weight,
                     " Entropy: ", hp$ent_weight)

  ggplot2::ggplot(data = d,
                  ggplot2::aes(x = .data$Step,
                               y = .data$Loss,
                               color = .data$Type)) +
    ggplot2::geom_line(linewidth = .8) +
    # Honor the caller's choice of y-axis transformation; this was previously
    # hard-coded to "log10" so the transform argument had no effect.
    ggplot2::scale_y_continuous(transform = transform) +
    ggplot2::ggtitle(paste0(species(bf), " Loss"), subtitle = subtitle)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.