-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimportance.R
141 lines (115 loc) · 4.79 KB
/
importance.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Flag the rows of `dataset` that satisfy a threshold rule on one column.
#
# dataset:   data frame (or matrix) holding the column to test
# rule_col:  name or index of the column the rule applies to
# compare:   comparison operator as a string: "<", "<=", ">" or ">="
# threshold: value the column is compared against
#
# Returns the result of `dataset[, rule_col] <compare> threshold`.
# Any other operator is an error: without a default arm, switch() would
# silently return NULL, and a NULL row mask selects nothing downstream.
subgroup <- function(dataset, rule_col, compare, threshold) {
  values <- dataset[, rule_col]
  switch(compare,
    "<" = values < threshold,
    "<=" = values <= threshold,
    ">" = values > threshold,
    ">=" = values >= threshold,
    stop("unsupported comparison operator: ", compare, call. = FALSE)
  )
}
# Difference in mean response between the two arms, restricted to a subgroup.
#
# dataset:  data frame with a response column `y` and an indicator column `trt`
# subgroup: logical row mask selecting the subgroup
#
# Returns mean(y) over subgroup rows with trt == 1 minus mean(y) over the
# remaining subgroup rows.
avg_trt_effect <- function(dataset, subgroup) {
  outcome <- dataset$y
  in_trt_arm <- dataset$trt == 1
  trt_mean <- mean(outcome[in_trt_arm & subgroup])
  other_mean <- mean(outcome[!in_trt_arm & subgroup])
  trt_mean - other_mean
}
# requires the randomForest, FSelector, rpart and rpart.utils packages to be installed
#install.packages("randomForest")
#install.packages("FSelector")
library("randomForest")
library("FSelector")
library(rpart)
#library(rattle)
library(rpart.utils)
# change to appropriate directory for your laptop
# NOTE(review): hard-coded, machine-specific working directory; the relative
# "Data/..." paths below are resolved against it.
setwd("C:/Users/Anne/Downloads/subgroup_prediction")
data <- read.csv("Data/Data.csv")
max_datasets <- max(data$dataset)

# Number of rows in every output frame. Whole columns are assigned inside the
# loop, so each per-dataset slice of `data` must have this many rows.
n_obs <- 240

# for now, make simple
# Template frame: an `id` column plus one column per dataset.
subgroups_base <- data.frame(matrix("", n_obs, max_datasets + 1))
names(subgroups_base) <- c(
  "id",
  paste("dataset", seq_len(max_datasets), sep = "_")
)
subgroups_base$id <- seq_len(n_obs)

# One output frame per filtering variant:
#   all       - subgroup indicator, never filtered
#   base      - zeroed when the observed effect difference is not below -0.6
#   fancy     - zeroed when the split variable is not the most important one
#   fancy2    - zeroed when the split variable's importance is below 0.5
#   fancy3    - zeroed when the split variable's importance is below 1
#   fancier*  - the matching fancy* filter combined with the base filter
subgroups_fancier <- subgroups_base
subgroups_all <- subgroups_base
subgroups_fancy <- subgroups_base
subgroups_fancy2 <- subgroups_base
subgroups_fancier2 <- subgroups_base
subgroups_fancy3 <- subgroups_base
subgroups_fancier3 <- subgroups_base

# Tallies of how many datasets failed each filter (not written to disk here).
no_diff <- 0
not_first <- 0
not_point5 <- 0
not_one <- 0

for (i in seq_len(max_datasets)) {
  col <- paste("dataset", i, sep = "_")
  dataset <- data[data$dataset == i, ]
  #dataset[,5:24] <- ifelse(dataset[,5:24] > 0, 1, 0) # if we want to collapse 2's and 1's
  x <- dataset[, c(3, 5:ncol(dataset))] # covariates plus trt column
  y <- dataset$y # response column

  # create datasets for virtual twin predictions
  dataset_all_trt <- dataset
  dataset_all_ctrl <- dataset
  dataset_all_trt$trt <- 1
  dataset_all_ctrl$trt <- 0

  # random forest regression
  rf_fit <- randomForest(x, y)
  trt.pred <- predict(rf_fit, dataset_all_trt)
  ctrl.pred <- predict(rf_fit, dataset_all_ctrl)

  # virtual twin treatment effect estimates
  z <- trt.pred - ctrl.pred

  # try one rule based on all data
  rpart_test <- dataset[, 5:ncol(dataset)]
  rpart_test$z <- z
  fit <- rpart(z ~ ., data = rpart_test, maxdepth = 1)

  # fit$where holds, per observation, the row of fit$frame for its leaf;
  # row 2 is presumably the left child of the single split - TODO confirm.
  in_subgroup <- fit$where == 2
  indicator <- ifelse(in_subgroup, 1, 0)

  # Start every variant from the raw indicator; the filters below zero out
  # the variants whose criteria are not met. (Previously fancier, fancier2,
  # fancy3 and fancier3 were never populated and stayed as empty strings
  # whenever their filters passed.)
  subgroups_all[, col] <- indicator
  subgroups_base[, col] <- indicator
  subgroups_fancy[, col] <- indicator
  subgroups_fancier[, col] <- indicator
  subgroups_fancy2[, col] <- indicator
  subgroups_fancier2[, col] <- indicator
  subgroups_fancy3[, col] <- indicator
  subgroups_fancier3[, col] <- indicator

  trt_diff <- mean(dataset$y[in_subgroup & dataset$trt == 1]) -
    mean(dataset$y[in_subgroup & dataset$trt == 0])

  # subgroup doesn't look meaningful
  # A NaN difference (an arm is empty inside the node) is treated the same
  # way instead of aborting the `if` with a missing-value error.
  if (is.na(trt_diff) || trt_diff > -.6) {
    subgroups_base[, col] <- 0
    subgroups_fancier[, col] <- 0
    subgroups_fancier2[, col] <- 0
    subgroups_fancier3[, col] <- 0
    no_diff <- no_diff + 1
  }

  # Variable used by the split. When rpart produced no split (fit$frame has
  # only the root row) there is no rule, so fall through to the
  # "not in importance list" branch below.
  split_found <- nrow(fit$frame) > 1
  rule_col <- if (split_found) names(rpart.lists(fit)$L)[1] else NA_character_

  # not in top important rules!
  if (!(rule_col %in% names(fit$variable.importance))) {
    subgroups_fancy[, col] <- 0
    subgroups_fancier[, col] <- 0
    subgroups_fancy2[, col] <- 0
    subgroups_fancier2[, col] <- 0
    subgroups_fancy3[, col] <- 0
    subgroups_fancier3[, col] <- 0
    not_first <- not_first + 1
    not_point5 <- not_point5 + 1
    not_one <- not_one + 1
  } else {
    rule_importance <- fit$variable.importance[rule_col]

    # not the most important rule!
    if (rule_col != names(fit$variable.importance)[1]) {
      subgroups_fancy[, col] <- 0
      subgroups_fancier[, col] <- 0
      not_first <- not_first + 1
    }
    # importance isn't over some threshold
    if (rule_importance < .5) {
      subgroups_fancy2[, col] <- 0
      subgroups_fancier2[, col] <- 0
      not_point5 <- not_point5 + 1
    }
    # importance isn't over another threshold
    if (rule_importance < 1) {
      subgroups_fancy3[, col] <- 0
      subgroups_fancier3[, col] <- 0
      not_one <- not_one + 1
    }
  }
}

# Write one CSV per variant (TRUE/FALSE spelled out; T/F are reassignable).
write.csv(subgroups_all, "Data/subgroups_all.csv", row.names = FALSE)
write.csv(subgroups_base, "Data/subgroups_base.csv", row.names = FALSE)
write.csv(subgroups_fancy, "Data/subgroups_fancy.csv", row.names = FALSE)
write.csv(subgroups_fancier, "Data/subgroups_fancier.csv", row.names = FALSE)
write.csv(subgroups_fancy2, "Data/subgroups_fancy2.csv", row.names = FALSE)
write.csv(subgroups_fancier2, "Data/subgroups_fancier2.csv", row.names = FALSE)
write.csv(subgroups_fancy3, "Data/subgroups_fancy3.csv", row.names = FALSE)
write.csv(subgroups_fancier3, "Data/subgroups_fancier3.csv", row.names = FALSE)