Skip to content

Commit

Permalink
smartAlign2(): correctly calculate positions of boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
jokergoo committed Mar 20, 2020
1 parent 5969936 commit 448f5de
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 71 deletions.
1 change: 1 addition & 0 deletions NEWS
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CHANGES in VERSION 2.3.2
* move scripts in test_not_run/ to tests/ folder
* `Heatmap()`: `cluster_row_slice`/`cluster_column_slice` set to TRUE
by default for character matrix and when dendrogram is already provided.
* `smartAlign2()`: improved the code

========================

Expand Down
2 changes: 1 addition & 1 deletion R/HeatmapAnnotation-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,7 @@ anno_type = function(ha) {
# ha[, 1:2]
# ha[1:5, c("foo", "sth")]
"[.HeatmapAnnotation" = function(x, i, j) {

if(!missing(j)) {
if(is.character(j)) {
j = which(names(x@anno_list) %in% j)
Expand Down Expand Up @@ -1004,7 +1005,6 @@ anno_type = function(ha) {
x2@gap[length(x2@gap)] = unit(0, "mm")

size(x2) = sum(x2@anno_size) + sum(x2@gap) - x2@gap[length(x2@gap)]

}

extended = unit(c(0, 0, 0, 0), "mm")
Expand Down
3 changes: 3 additions & 0 deletions R/oncoPrint.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ oncoPrint = function(mat,
heatmap_legend_param = list(title = "Alterations"),
...) {

dev.null()
on.exit(dev.off2())

arg_list = as.list(match.call())[-1]
arg_names = names(arg_list)

Expand Down
115 changes: 66 additions & 49 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -965,52 +965,39 @@ resize_matrix = function(mat, nr, nc) {
# -start position which corresponds to the start (bottom or left) of the rectangle-shapes.
# -end position which corresponds to the end (top or right) of the rectanglar shapes.
# -range data ranges (the minimal and maximal values)
# -range_fixed Whether the range is fixed for ``range`` when adjust the positions?
# -plot Whether plot the correspondance between the original positions and the adjusted positions. Only for testing.
#
# == details
# This is an improved version of the `circlize::smartAlign`.
#
# It adjusts the positions of the rectangular shapes to make them do not overlap
#
# == example
# require(circlize)
# make_plot = function(pos1, pos2, range) {
# oxpd = par("xpd")
# par(xpd = NA)
# plot(NULL, xlim = c(0, 4), ylim = range, ann = FALSE)
# col = rand_color(nrow(pos1), transparency = 0.5)
# rect(0.5, pos1[, 1], 1.5, pos1[, 2], col = col)
# rect(2.5, pos2[, 1], 3.5, pos2[, 2], col = col)
# segments(1.5, rowMeans(pos1), 2.5, rowMeans(pos2))
# par(xpd = oxpd)
# }
#
# range = c(0, 10)
# pos1 = rbind(c(1, 2), c(5, 7))
# make_plot(pos1, smartAlign2(pos1, range = range), range)
# smartAlign2(pos1, range = range, plot = TRUE)
#
# range = c(0, 10)
# pos1 = rbind(c(-0.5, 2), c(5, 7))
# make_plot(pos1, smartAlign2(pos1, range = range), range)
# smartAlign2(pos1, range = range, plot = TRUE)
#
# pos1 = rbind(c(-1, 2), c(3, 4), c(5, 6), c(7, 11))
# pos1 = pos1 + runif(length(pos1), max = 0.3, min = -0.3)
# omfrow = par("mfrow")
# par(mfrow = c(3, 3))
# for(i in 1:9) {
# ind = sample(4, 4)
# make_plot(pos1[ind, ], smartAlign2(pos1[ind, ], range = range), range)
# smartAlign2(pos1[ind, ], range = range, plot = TRUE)
# }
# par(mfrow = omfrow)
#
# pos1 = rbind(c(3, 6), c(4, 7))
# make_plot(pos1, smartAlign2(pos1, range = range), range)
# smartAlign2(pos1, range = range, plot = TRUE)
#
# pos1 = rbind(c(1, 8), c(3, 10))
# make_plot(pos1, smartAlign2(pos1, range = range), range)
# make_plot(pos1, smartAlign2(pos1, range = range, range_fixed = FALSE), range)
# smartAlign2(pos1, range = range, plot = TRUE)
#
smartAlign2 = function(start, end, range, range_fixed = TRUE) {
smartAlign2 = function(start, end, range, plot = FALSE) {

if(missing(end)) {
x1 = start[, 1]
Expand All @@ -1024,16 +1011,54 @@ smartAlign2 = function(start, end, range, range_fixed = TRUE) {
range = range(c(x1, x2))
}

make_plot = function(pos1, pos2, main = "") {
oxpd = par("xpd")
par(xpd = NA)
plot(NULL, xlim = c(0, 4), ylim = range(c(pos1, pos2)), ann = FALSE, axes = FALSE)
col = rand_color(nrow(pos1), transparency = 0.5)
rect(0.5, pos1[, 1], 1.5, pos1[, 2], col = col)
rect(2.5, pos2[, 1], 3.5, pos2[, 2], col = col)
segments(1.5, rowMeans(pos1), 2.5, rowMeans(pos2))
text(1, -0.02, "original", adj = c(0.5, 1))
text(3, -0.02, "adjusted", adj = c(0.5, 1))
title(main)
par(xpd = oxpd)
}


od = order(x1)
rk = rank(x1, ties.method = "random")
x1 = x1[od]
x2 = x2[od]
mid = (x1 + x2)/2
h = x2 - x1
n = length(x1)

ox1 = x1
ox2 = x2

# sum of box heights exceeds range
if(sum(h) > range[2] - range[1]) {
a = ((range[2] - h[n]/2) - (range[1] + h[1]/2))/(n-1)
m = range[1] + 1:n*a
nx1 = m - h/2
nx2 = m + h/2

if(plot) {
make_plot(cbind(ox1, ox2), cbind(nx1, nx2), main = "sum of box heights exceeds range")
}

df = data.frame(start = x1, end = x2)
return(df[rk, , drop = FALSE])
}

ncluster.before = -1
ncluster = length(x1)
i_try = 0
i_try = 1
while(ncluster.before != ncluster) {

if(i_try > 100) break

ncluster.before = ncluster
cluster = rep(0, length(x1))
i_cluster = 1
Expand All @@ -1048,43 +1073,35 @@ smartAlign2 = function(start, end, range, range_fixed = TRUE) {
}
}
ncluster = length(unique(cluster))

if(ncluster.before == ncluster && i_try > 0) break

if(i_try > 100) break

# tile intervals in each cluster and re-assign x1 and x2
new_x1 = numeric(length(x1))
new_x2 = numeric(length(x2))
for(i_cluster in unique(cluster)) {
index = which(cluster == i_cluster)
total_len = sum(x2[index] - x1[index]) # sum of the height in the cluster
mid = (min(x1[index]) + max(x2[index]))/2
if(total_len > range[2] - range[1]) {
# tp = seq(range[1], range[2], length = length(index) + 1)
if(range_fixed) {
tp = cumsum(c(0, h[index]/sum(h[index])))*(range[2] - range[1]) + range[1]
} else {
tp = c(0, cumsum(h[index])) + mid - sum(h[index])/2
}
} else if(mid - total_len/2 < range[1]) { # if it exceed the bottom
# tp = seq(range[1], range[1] + total_len, length = length(index) + 1)
tp = c(0, cumsum(h[index])) + range[1]
} else if(mid + total_len/2 > range[2]) {
# tp = seq(range[2] - total_len, range[2], length = length(index) + 1)
tp = range[2] - rev(c(0, cumsum(h[index])))
box_height = sum(h[index]) # sum of the height in the cluster
box_mid = (min(x1[index]) + max(x2[index]))/2
box_x1 = box_mid - box_height/2
box_x2 = box_mid + box_height/2

if(box_x1 < range[1]) { # if it exceed the bottom
new_x2[index] = range[1] + cumsum(h[index])
new_x1[index] = new_x2[index] - h[index]
} else if(box_x2 > range[2]) {
new_x1[index] = range[2] - rev(cumsum(h[index]))
new_x2[index] = new_x1[index] + h[index]
} else {
# tp = seq(mid - total_len/2, mid + total_len/2, length = length(index)+1)
tp = c(0, cumsum(h[index])) + mid - sum(h[index])/2
new_x2[index] = box_x1 + cumsum(h[index])
new_x1[index] = new_x2[index] - h[index]
}
new_x1[index] = tp[-length(tp)]
new_x2[index] = tp[-1]
}
mid = (new_x1 + new_x2)/2
h = (x2 - x1)

x1 = mid - h/2
x2 = mid + h/2

x1 = new_x1
x2 = new_x2

if(plot) {
make_plot(cbind(ox1, ox2), cbind(x1, x2), main = qq("@{i_try}th try, @{ncluster} clusters"))
}

i_try = i_try + 1
}
Expand Down
27 changes: 7 additions & 20 deletions man/smartAlign2.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ Adjust positions of rectanglar shapes
Adjust positions of rectanglar shapes
}
\usage{
smartAlign2(start, end, range, range_fixed = TRUE)
smartAlign2(start, end, range, plot = FALSE)
}
\arguments{

\item{start}{position which corresponds to the start (bottom or left) of the rectangle-shapes.}
\item{end}{position which corresponds to the end (top or right) of the rectanglar shapes.}
\item{range}{data ranges (the minimal and maximal values)}
\item{range_fixed}{Whether the range is fixed for \code{range} when adjust the positions?}
\item{plot}{Whether plot the correspondance between the original positions and the adjusted positions. Only for testing.}

}
\details{
Expand All @@ -23,40 +23,27 @@ This is an improved version of the \code{\link[circlize]{smartAlign}}.
It adjusts the positions of the rectangular shapes to make them do not overlap
}
\examples{
require(circlize)
make_plot = function(pos1, pos2, range) {
oxpd = par("xpd")
par(xpd = NA)
plot(NULL, xlim = c(0, 4), ylim = range, ann = FALSE)
col = rand_color(nrow(pos1), transparency = 0.5)
rect(0.5, pos1[, 1], 1.5, pos1[, 2], col = col)
rect(2.5, pos2[, 1], 3.5, pos2[, 2], col = col)
segments(1.5, rowMeans(pos1), 2.5, rowMeans(pos2))
par(xpd = oxpd)
}

range = c(0, 10)
pos1 = rbind(c(1, 2), c(5, 7))
make_plot(pos1, smartAlign2(pos1, range = range), range)
smartAlign2(pos1, range = range, plot = TRUE)

range = c(0, 10)
pos1 = rbind(c(-0.5, 2), c(5, 7))
make_plot(pos1, smartAlign2(pos1, range = range), range)
smartAlign2(pos1, range = range, plot = TRUE)

pos1 = rbind(c(-1, 2), c(3, 4), c(5, 6), c(7, 11))
pos1 = pos1 + runif(length(pos1), max = 0.3, min = -0.3)
omfrow = par("mfrow")
par(mfrow = c(3, 3))
for(i in 1:9) {
ind = sample(4, 4)
make_plot(pos1[ind, ], smartAlign2(pos1[ind, ], range = range), range)
smartAlign2(pos1[ind, ], range = range, plot = TRUE)
}
par(mfrow = omfrow)

pos1 = rbind(c(3, 6), c(4, 7))
make_plot(pos1, smartAlign2(pos1, range = range), range)
smartAlign2(pos1, range = range, plot = TRUE)

pos1 = rbind(c(1, 8), c(3, 10))
make_plot(pos1, smartAlign2(pos1, range = range), range)
make_plot(pos1, smartAlign2(pos1, range = range, range_fixed = FALSE), range)
smartAlign2(pos1, range = range, plot = TRUE)
}
24 changes: 23 additions & 1 deletion tests/test-AnnotationFunction.R
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,6 @@ panel_fun2 = function(index, nm) {
anno2 = anno_zoom(align_to = subgroup, which = "row", panel_fun = panel_fun2,
gap = unit(1, "cm"), width = unit(3, "cm"), side = "left")

# in infinite loop
draw(Heatmap(m, right_annotation = rowAnnotation(subgroup = subgroup, foo = anno,
show_annotation_name = FALSE),
left_annotation = rowAnnotation(bar = anno2, subgroup = subgroup, show_annotation_name = FALSE),
Expand All @@ -595,3 +594,26 @@ draw(Heatmap(m, right_annotation = rowAnnotation(foo = anno),
left_annotation = rowAnnotation(bar = anno2),
show_row_dend = FALSE,
row_split = subgroup))

set.seed(12345)
mat = matrix(rnorm(30*10), nr = 30)
row_split = c(rep("a", 10), rep("b", 5), rep("c", 2), rep("d", 3),
rep("e", 2), letters[10:17])
row_split = factor(row_split)

panel_fun = function(index, name) {
pushViewport(viewport())
grid.rect()
grid.text(name)
popViewport()
}

anno = anno_zoom(align_to = row_split, which = "row", panel_fun = panel_fun,
size = unit(0.5, "cm"), width = unit(4, "cm"))

# > dev.size()
# [1] 3.938326 4.502203
dev.new(width = 3.938326, height = 4.502203)
draw(Heatmap(mat, right_annotation = rowAnnotation(foo = anno),
row_split = row_split))

0 comments on commit 448f5de

Please sign in to comment.