Skip to content

Commit

Permalink
Update dev version
Browse files Browse the repository at this point in the history
  • Loading branch information
tripartio committed Nov 11, 2024
1 parent ba95e4b commit 1c5aef3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 61 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: ale
Title: Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
Version: 0.3.0.20241110
Version: 0.3.0.20241111
Authors@R: c(
person("Chitu", "Okoli", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5574-7572"))
Expand Down
65 changes: 5 additions & 60 deletions R/stats.R
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,6 @@ summarize_conf_regions_1D <- function(
mid_bar = case_when(
.data$.y_hi < y_zeroed_summary['med_lo'] ~ 'below',
.data$.y_lo > y_zeroed_summary['med_hi'] ~ 'above',
# .data$.y_hi < y_summary[['med_lo', 1]] ~ 'below',
# .data$.y_lo > y_summary[['med_hi', 1]] ~ 'above',
.default = 'overlap'
) |>
factor(ordered = TRUE, levels = c('below', 'overlap', 'above')),
Expand Down Expand Up @@ -574,14 +572,13 @@ summarize_conf_regions_1D <- function(

cr <- cr |>
rename(
# x = '.x',
n = '.n',
y = '.y',
) |>
mutate(
pct = (.data$n / sum(it.ale_data$.n)) * 100,
# Convert x column from ordinal to character for consistency across terms
x = as.character(x),
x = as.character(.data$x),
) |>
select('x', 'y', 'mid_bar', 'n', 'pct')
}
Expand All @@ -597,33 +594,13 @@ summarize_conf_regions_1D <- function(
# https://bard.google.com/chat/ea68c7b9e8437179
select(
'term',
# any_of is used because categorical variables do not have 'start_x', 'end_x', 'x_span_pct'
# while numeric values do not have 'x'
# any_of is used because categorical variables do not have 'start_x', 'end_x', 'x_span_pct' while numeric values do not have 'x'
any_of(c('x', 'start_x', 'end_x', 'x_span_pct')),
'n', 'pct',
any_of(c('y', 'start_y', 'end_y', 'trend')),
'mid_bar'
)

# browser()


# Highlight which confidence regions are statistically significant
# sig_conf_regions <- map2(
# cr_by_term, names(cr_by_term),
# \(it.conf_tbl, it.term) {
# it.conf_tbl$term <- it.term
#
# if ('x' %in% names(it.conf_tbl)) {
# # Convert x column from ordinal to character for consistency across terms
# it.conf_tbl$x <- as.character(it.conf_tbl$x)
# }
#
# it.conf_tbl
# }
# ) |>
# bind_rows() |>

sig_conf_regions <- cr_by_term |>
filter(.data$mid_bar != 'overlap')

Expand Down Expand Up @@ -665,16 +642,14 @@ summarize_conf_regions_2D <- function(
mid_bar = case_when(
.data$.y_hi < y_zeroed_summary['med_lo'] ~ 'below',
.data$.y_lo > y_zeroed_summary['med_hi'] ~ 'above',
# .data$.y_hi < y_summary[['med_lo', 1]] ~ 'below',
# .data$.y_lo > y_summary[['med_hi', 1]] ~ 'above',
.default = 'overlap'
) |>
factor(ordered = TRUE, levels = c('below', 'overlap', 'above')),
) |>
select(-c('.y_lo':'.y_hi')) |>
rename(
n = .n,
y = .y
n = '.n',
y = '.y'
)

# Initialize cr_groups, used only if one or both x variables is non-numeric
Expand Down Expand Up @@ -716,20 +691,11 @@ summarize_conf_regions_2D <- function(
n = sum(.data$n),
pct = (n / total_n) * 100,
y = mean(.data$y),
# n_below = sum(mid_bar == 'below'),
# n_overlap = sum(mid_bar == 'overlap'),
# n_above = sum(mid_bar == 'above'),
# pct_below = (n_below / total_n) * 100,
# pct_overlap = (n_overlap / total_n) * 100,
# pct_above = (n_above / total_n) * 100,
)

# Rename the x variables with their original variable names
x1_x2_names <- x1_x2_names |>
stringr::str_remove("\\.bin$|\\.ceil$")
# names(cr)[1] <- x1_x2_names[1]
# names(cr)[2] <- x1_x2_names[2]


# Convert x data columns uniformly to character format
cr[[1]] <- as.character(cr[[1]])
Expand All @@ -738,39 +704,19 @@ summarize_conf_regions_2D <- function(
# Rename the x data columns consistently
names(cr)[1:2] <- c('x1', 'x2')

# Return value for map function
cr |>
mutate(
term1 = x1_x2_names[1],
term2 = x1_x2_names[2],
) |>
select('term1', 'x1', 'term2', 'x2', everything())


}) |>
set_names(names(ale_data_list)) |>
bind_rows()

# browser()

# Highlight which confidence regions are statistically significant
sig_conf_regions <- cr_by_term |>
# map(\(it.conf_tbl) {
# # browser()
# x1_x2_names <- names(it.conf_tbl)[1:2]
# # Convert x data columns uniformly to character format
# it.conf_tbl[[1]] <- as.character(it.conf_tbl[[1]])
# it.conf_tbl[[2]] <- as.character(it.conf_tbl[[2]])
#
# # Rename the x data columns consistently
# names(it.conf_tbl)[1:2] <- c('x1', 'x2')
#
# it.conf_tbl |>
# mutate(
# term1 = x1_x2_names[1],
# term2 = x1_x2_names[2],
# ) |>
# select('term1', 'x1', 'term2', 'x2', everything())
# }) |>
filter(.data$mid_bar != 'overlap')


Expand All @@ -791,7 +737,6 @@ summarize_conf_regions_1D_in_words <- function(
band_type = 'median'
) {
map_chr(1:nrow(conf_region_summary), \(.row_num) {
# browser()
with(
conf_region_summary[.row_num, ],
if (exists('start_x')) { # conf_region_summary is numeric
Expand Down

0 comments on commit 1c5aef3

Please sign in to comment.