-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Incrementing the study case... Deform is not working well (signal pro…
…blem)
- Loading branch information
1 parent
ef8616e
commit 62614b6
Showing
10 changed files
with
361 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue Jan 29 18:32:13 2019 | ||
@author: Maria | ||
""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import images | ||
#from math import ceil | ||
#from functools import reduce | ||
from scipy import ndimage | ||
|
||
class DistanceTransformType(): | ||
euclidean = 0 | ||
|
||
class SimpleInterpolation(): | ||
''' | ||
This class calculates alpha values that gives some 'weight' to the | ||
pixels | ||
''' | ||
def __init__(self, sim_image, transform_type: DistanceTransformType): | ||
|
||
if transform_type == DistanceTransformType.euclidean: | ||
None | ||
self.dist_transform, _ = ndimage.distance_transform_edt(sim_image,return_indices = True) | ||
self.alpha_dist = None | ||
self.n_pixels = None | ||
self.pts = None | ||
self.ref_pts = None | ||
self.denom = None | ||
self.inverse_d = None | ||
|
||
else: | ||
print('Error') | ||
|
||
def GetResult(self): | ||
self.alpha_dist = np.zeros((self.dist_transform.shape[0], self.dist_transform.shape[1])) | ||
self.n_pixels = self.dist_transform.shape[0]*self.dist_transform.shape[1] | ||
self.pts = sim_image_object.getOtherPointsCoords() | ||
self.ref_pts = sim_image_object.getPointRefCoords() | ||
self.denom = self.dist_transform[self.pts].reshape( self.n_pixels - 1, self.dist_transform.shape[2]) | ||
self.denom = 1/ self.denom[:] | ||
self.inverse_d = np.sum(self.denom[:], axis=0) | ||
|
||
return self.alpha_dist | ||
|
||
|
||
#dist_transform, indices = ndimage.distance_transform_edt(sim_image,return_indices = True) | ||
|
||
#alpha_dist = np.zeros((dist_transform.shape[0],dist_transform.shape[1])) | ||
#n_pixels = dist_transform.shape[0]*dist_transform.shape[1] | ||
# | ||
##pts = np.where(sim_image > 0) | ||
##ref_pts = np.where(sim_image < 1) | ||
# | ||
#pts = sim_image_object.getOtherPointsCoords() | ||
#ref_pts = sim_image_object.getPointRefCoords() | ||
|
||
#denom = dist_transform[pts].reshape(n_pixels - 1, dist_transform.shape[2]) | ||
#denom = 1/denom[:] | ||
#inverse_d = np.sum(denom[:], axis=0) | ||
|
||
w_dist = ((1/dist_transform[pts]).reshape(n_pixels - 1, dist_transform.shape[2]))/inverse_d | ||
|
||
alpha = np.zeros(dist_transform.shape) | ||
alpha[pts] = w_dist.ravel() | ||
alpha[ref_pts] = 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .affine_registration import affine_registration | ||
from .rigid_registration import rigid_registration | ||
from .deformable_registration import gaussian_kernel, deformable_registration |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from builtins import super | ||
import numpy as np | ||
from .expectation_maximization_registration import expectation_maximization_registration | ||
|
||
class affine_registration(expectation_maximization_registration): | ||
def __init__(self, B=None, t=None, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.B = np.eye(self.D) if B is None else B | ||
self.t = np.atleast_2d(np.zeros((1, self.D))) if t is None else t | ||
|
||
def update_transform(self): | ||
muX = np.divide(np.sum(np.dot(self.P, self.X), axis=0), self.Np) | ||
muY = np.divide(np.sum(np.dot(np.transpose(self.P), self.Y), axis=0), self.Np) | ||
|
||
self.XX = self.X - np.tile(muX, (self.N, 1)) | ||
YY = self.Y - np.tile(muY, (self.M, 1)) | ||
|
||
self.A = np.dot(np.transpose(self.XX), np.transpose(self.P)) | ||
self.A = np.dot(self.A, YY) | ||
|
||
self.YPY = np.dot(np.transpose(YY), np.diag(self.P1)) | ||
self.YPY = np.dot(self.YPY, YY) | ||
|
||
self.B = np.linalg.solve(np.transpose(self.YPY), np.transpose(self.A)) | ||
self.t = np.transpose(muX) - np.dot(np.transpose(self.B), np.transpose(muY)) | ||
|
||
def transform_point_cloud(self, Y=None): | ||
if Y is None: | ||
self.TY = np.dot(self.Y, self.B) + np.tile(self.t, (self.M, 1)) | ||
return | ||
else: | ||
return np.dot(Y, self.B) + np.tile(self.t, (Y.shape[0], 1)) | ||
|
||
def update_variance(self): | ||
qprev = self.q | ||
|
||
trAB = np.trace(np.dot(self.A, self.B)) | ||
xPx = np.dot(np.transpose(self.Pt1), np.sum(np.multiply(self.XX, self.XX), axis =1)) | ||
trBYPYP = np.trace(np.dot(np.dot(self.B, self.YPY), self.B)) | ||
self.q = (xPx - 2 * trAB + trBYPYP) / (2 * self.sigma2) + self.D * self.Np/2 * np.log(self.sigma2) | ||
self.err = np.abs(self.q - qprev) | ||
|
||
self.sigma2 = (xPx - trAB) / (self.Np * self.D) | ||
|
||
if self.sigma2 <= 0: | ||
self.sigma2 = self.tolerance / 10 | ||
|
||
def get_registration_parameters(self): | ||
return self.B, self.t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from builtins import super | ||
import numpy as np | ||
from .expectation_maximization_registration import expectation_maximization_registration | ||
|
||
def gaussian_kernel(Y, beta): | ||
(M, D) = Y.shape | ||
XX = np.reshape(Y, (1, M, D)) | ||
YY = np.reshape(Y, (M, 1, D)) | ||
XX = np.tile(XX, (M, 1, 1)) | ||
YY = np.tile(YY, (1, M, 1)) | ||
diff = XX-YY | ||
diff = np.multiply(diff, diff) | ||
diff = np.sum(diff, 2) | ||
return np.exp(-diff / (2 * beta)) | ||
|
||
class deformable_registration(expectation_maximization_registration): | ||
def __init__(self, alpha=2, beta=2, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.alpha = 2 if alpha is None else alpha | ||
self.beta = 2 if alpha is None else beta | ||
self.W = np.zeros((self.M, self.D)) | ||
self.G = gaussian_kernel(self.Y, self.beta) | ||
|
||
def update_transform(self): | ||
A = np.dot(np.diag(self.P1), self.G) + self.alpha * self.sigma2 * np.eye(self.M) | ||
B = np.dot(self.P, self.X) - np.dot(np.diag(self.P1), self.Y) | ||
self.W = np.linalg.solve(A, B) | ||
|
||
def transform_point_cloud(self, Y=None): | ||
if Y is None: | ||
self.TY = self.Y + np.dot(self.G, self.W) | ||
return | ||
else: | ||
return Y + np.dot(self.G, self.W) | ||
|
||
def update_variance(self): | ||
qprev = self.sigma2 | ||
|
||
xPx = np.dot(np.transpose(self.Pt1), np.sum(np.multiply(self.X, self.X), axis=1)) | ||
yPy = np.dot(np.transpose(self.P1), np.sum(np.multiply(self.TY, self.TY), axis=1)) | ||
trPXY = np.sum(np.multiply(self.TY, np.dot(self.P, self.X))) | ||
|
||
self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D) | ||
|
||
if self.sigma2 <= 0: | ||
self.sigma2 = self.tolerance / 10 | ||
self.err = np.abs(self.sigma2 - qprev) | ||
|
||
def get_registration_parameters(self): | ||
return self.G, self.W |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
def initialize_sigma2(X, Y): | ||
(N, D) = X.shape | ||
(M, _) = Y.shape | ||
XX = np.reshape(X, (1, N, D)) | ||
YY = np.reshape(Y, (M, 1, D)) | ||
XX = np.tile(XX, (M, 1, 1)) | ||
YY = np.tile(YY, (1, N, 1)) | ||
diff = XX - YY | ||
err = np.multiply(diff, diff) | ||
return np.sum(err) / (D * M * N) | ||
|
||
class expectation_maximization_registration(object): | ||
def __init__(self, X, Y, sigma2=None, max_iterations=100, tolerance=0.001, w=0, *args, **kwargs): | ||
if type(X) is not np.ndarray or X.ndim != 2: | ||
raise ValueError("The target point cloud (X) must be at a 2D numpy array.") | ||
if type(Y) is not np.ndarray or Y.ndim != 2: | ||
raise ValueError("The source point cloud (Y) must be a 2D numpy array.") | ||
if X.shape[1] != Y.shape[1]: | ||
raise ValueError("Both point clouds need to have the same number of dimensions.") | ||
|
||
self.X = X | ||
self.Y = Y | ||
self.sigma2 = sigma2 | ||
(self.N, self.D) = self.X.shape | ||
(self.M, _) = self.Y.shape | ||
self.tolerance = tolerance | ||
self.w = w | ||
self.max_iterations = max_iterations | ||
self.iteration = 0 | ||
self.err = self.tolerance + 1 | ||
self.P = np.zeros((self.M, self.N)) | ||
self.Pt1 = np.zeros((self.N, )) | ||
self.P1 = np.zeros((self.M, )) | ||
self.Np = 0 | ||
|
||
def register(self, callback=lambda **kwargs: None): | ||
self.transform_point_cloud() | ||
if self.sigma2 is None: | ||
self.sigma2 = initialize_sigma2(self.X, self.TY) | ||
self.q = -self.err - self.N * self.D/2 * np.log(self.sigma2) | ||
while self.iteration < self.max_iterations and self.err > self.tolerance: | ||
self.iterate() | ||
if callable(callback): | ||
kwargs = {'iteration': self.iteration, 'error': self.err, 'X': self.X, 'Y': self.TY} | ||
callback(**kwargs) | ||
|
||
return self.TY, self.get_registration_parameters() | ||
|
||
def get_registration_parameters(self): | ||
raise NotImplementedError("Registration parameters should be defined in child classes.") | ||
|
||
def iterate(self): | ||
self.expectation() | ||
self.maximization() | ||
self.iteration += 1 | ||
|
||
def expectation(self): | ||
P = np.zeros((self.M, self.N)) | ||
|
||
for i in range(0, self.M): | ||
diff = self.X - np.tile(self.TY[i, :], (self.N, 1)) | ||
diff = np.multiply(diff, diff) | ||
P[i, :] = P[i, :] + np.sum(diff, axis=1) | ||
|
||
c = (2 * np.pi * self.sigma2) ** (self.D / 2) | ||
c = c * self.w / (1 - self.w) | ||
c = c * self.M / self.N | ||
|
||
P = np.exp(-P / (2 * self.sigma2)) | ||
den = np.sum(P, axis=0) | ||
den = np.tile(den, (self.M, 1)) | ||
den[den==0] = np.finfo(float).eps | ||
den += c | ||
|
||
self.P = np.divide(P, den) | ||
self.Pt1 = np.sum(self.P, axis=0) | ||
self.P1 = np.sum(self.P, axis=1) | ||
self.Np = np.sum(self.P1) | ||
|
||
def maximization(self): | ||
self.update_transform() | ||
self.transform_point_cloud() | ||
self.update_variance() |
Oops, something went wrong.