PhysRegTumor.py
#!/usr/bin/env python3
import pickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "DejaVu Sans"
from scipy.ndimage import zoom, center_of_mass
from scipy.interpolate import griddata
import nibabel as nib
from FK import Solver
import argparse
import numpy as np
import os
from functools import partial
import odil
from odil.runtime import tf
from data_processing import correct_outside_skull_mask
printlog = odil.util.printlog
# Data tensors
global wm_data, gm_data, csf_data, segm_data, pet_data
global stored_affine
# Material properties
global dtype, gamma, gamma_ch
dtype = np.float32
gamma_ch = -2
gamma = tf.constant(1.0, dtype=dtype) * (-1.) * gamma_ch
# Weights for different components of the model
global BC_w, pde_w, balance_w, neg_w, core_w, edema_w, outside_w, params_w, symmetry_w
BC_w = 850
pde_w = 60000
balance_w = 215
neg_w = 80
core_w = 17.5
edema_w = 13.5
outside_w = 13.5
params_w = 98
symmetry_w = 3.2
# Diffusion and reaction parameters
global D_ch, R_ch, rho_ch, matter_th
D_ch = 0.13
R_ch = 10
rho_ch = 0.06
matter_th = 0.1
# Initial and threshold settings
global c_init, th_down_s, th_up_s
th_down_s = 0.22
th_up_s = 0.62
# Control points and tissue segmentation weights
global CM_pos
global pet_w
pet_w = 1.0
# Regularization factors for spatial and temporal components
global TS, kxreg, ktreg
kxreg = 11
ktreg = 80
def gauss_sol3d_tf(x, y, z, dx, dy, dz, init_scale):
# Experimentally chosen
Dt = 5.0
M = 250
# Apply scaling to the coordinates
x_scaled = x * dx / init_scale
y_scaled = y * dy / init_scale
z_scaled = z * dz / init_scale
# Gaussian function calculation
gauss = M / tf.pow(4 * tf.constant(np.pi, dtype=x.dtype) * Dt, 3/2) * tf.exp(- (tf.pow(x_scaled, 2) + tf.pow(y_scaled, 2) + tf.pow(z_scaled, 2)) / (4 * Dt))
# Apply thresholds
gauss = tf.where(gauss > 0.1, gauss, tf.zeros_like(gauss))
gauss = tf.where(gauss > 1, tf.ones_like(gauss, dtype=dtype), gauss)
return gauss
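# Illustrative sketch (not invoked by the pipeline): evaluating the Gaussian
# seed of gauss_sol3d_tf on a small toy grid. The grid size, spacing and the
# seed position below are made-up demonstration values.
def _demo_gauss_seed():
    n = 8
    xs = tf.range(n, dtype=dtype)
    gx, gy, gz = tf.meshgrid(xs, xs, xs, indexing='ij')
    x0 = y0 = z0 = n / 2.0  # hypothetical seed centre in grid units
    # Coordinates are passed relative to the seed centre; the result is
    # clipped to [0, 1] inside gauss_sol3d_tf and is nonzero only near it.
    return gauss_sol3d_tf(gx - x0, gy - y0, gz - z0,
                          dx=1.0, dy=1.0, dz=1.0, init_scale=0.8)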
def unet_loss(unet_data, c_field, th_low):
# Convert unet_data numpy array to a TensorFlow tensor
unet_data_tensor = tf.convert_to_tensor(unet_data, dtype=tf.float32)
# Ensure c_field is a float tensor (in case it isn't)
c_field = tf.cast(c_field, tf.float32)
# Create a mask where unet_data is 1
mask = tf.equal(unet_data_tensor, 1.0)
# Calculate the part of c_field that is under the threshold where unet_data is 1
under_threshold = tf.less(c_field, th_low)
relevant_under_threshold = tf.logical_and(mask, under_threshold)
# Calculate losses where the condition is met
# Loss is linearly scaled with c_field, reaching maximum when c_field is 0
losses = tf.where(relevant_under_threshold, (th_low - c_field) / th_low, 0.0)
# Calculate the maximum possible loss
max_possible_loss = tf.where(mask, 1.0, 0.0)
# Compute the total loss and normalize by the maximum possible loss
total_loss = tf.reduce_sum(losses)
max_loss = tf.reduce_sum(max_possible_loss)
# Normalize the total loss to be between 0 and 1
normalized_loss = total_loss / max_loss
return normalized_loss
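# Illustrative sketch (not invoked by the pipeline): unet_loss on a toy
# 2x2x2 volume. The mask marks four voxels as tumor; two of them fall below
# the threshold, so the normalized loss is (1.0 + 0.8) / 4 = 0.45.
def _demo_unet_loss():
    unet = np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]]], dtype=np.float32)
    c_field = tf.constant([[[0.0, 0.1], [0.9, 0.8]],
                           [[0.0, 0.0], [0.0, 0.0]]], dtype=tf.float32)
    return unet_loss(unet, c_field, th_low=0.5)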
def transform_c(c, mod):
global outside_skull_mask
# Use the outside_skull_mask to set c to zero outside of the skull.
# This is achieved by multiplying c with the inverted mask (where outside of the skull is zero).
c_masked = c * (1 - outside_skull_mask)
# Zero the tumor cells for the initial time step.
if mod == np:
c_masked = np.concatenate([c_masked[:0], np.zeros_like(c_masked)[:1], c_masked[1:]], axis=0)
else:
c_masked = tf.concat([c_masked[:0], tf.zeros_like(c_masked)[:1], c_masked[1:]], axis=0)
return c_masked
def transform_txyz(tx, ty, tz, x, y, z, mod):
global outside_skull_mask
nth = 0 #first time slice
#nth = tx.shape[0] - 1 #last time slice
if mod == np:
# Fix the trajectories for the initial time step.
tx = np.concatenate([tx[:nth], x[nth:nth+1], tx[nth + 1:]], axis=0)
ty = np.concatenate([ty[:nth], y[nth:nth+1], ty[nth + 1:]], axis=0)
tz = np.concatenate([tz[:nth], z[nth:nth+1], tz[nth + 1:]], axis=0)
# Fix the spatial boundary particles in all three dimensions.
tx = np.concatenate([x[:, :1, :, :], tx[:, 1:-1, :, :], x[:, -1:, :, :]], axis=1)
ty = np.concatenate([y[:, :, :1, :], ty[:, :, 1:-1, :], y[:, :, -1:, :]], axis=2)
tz = np.concatenate([z[:, :, :, :1], tz[:, :, :, 1:-1], z[:, :, :, -1:]], axis=3)
# Create tensors of the initial positions for particles outside the skull in all three dimensions.
x_masked = x * outside_skull_mask
y_masked = y * outside_skull_mask
z_masked = z * outside_skull_mask
# Use the mask to combine the fixed and moving portions in all three dimensions.
tx_fixed = tx * (1-outside_skull_mask) + x_masked
ty_fixed = ty * (1-outside_skull_mask) + y_masked
tz_fixed = tz * (1-outside_skull_mask) + z_masked
else:
# Fix the trajectories for the initial time step.
tx = tf.concat([tx[:nth], x[nth:nth+1], tx[nth + 1:]], axis=0)
ty = tf.concat([ty[:nth], y[nth:nth+1], ty[nth + 1:]], axis=0)
tz = tf.concat([tz[:nth], z[nth:nth+1], tz[nth + 1:]], axis=0)
# Fix the spatial boundary particles in all three dimensions.
tx = tf.concat([x[:, :1, :, :], tx[:, 1:-1, :, :], x[:, -1:, :, :]], axis=1)
ty = tf.concat([y[:, :, :1, :], ty[:, :, 1:-1, :], y[:, :, -1:, :]], axis=2)
tz = tf.concat([z[:, :, :, :1], tz[:, :, :, 1:-1], z[:, :, :, -1:]], axis=3)
# Create tensors of the initial positions for particles outside the skull in all three dimensions.
x_masked = x * outside_skull_mask
y_masked = y * outside_skull_mask
z_masked = z * outside_skull_mask
# Use the mask to combine the fixed and moving portions in all three dimensions.
tx_fixed = tx * (1-outside_skull_mask) + x_masked
ty_fixed = ty * (1-outside_skull_mask) + y_masked
tz_fixed = tz * (1-outside_skull_mask) + z_masked
return tx_fixed, ty_fixed, tz_fixed
def transform_txyz2(tx, ty, tz, x, y, z, mod):
global outside_skull_mask
#nth = 0 #first time slice
nth = tx.shape[0] - 1 #last time slice
if mod == np:
# Fix the trajectories for the selected time step (here the last slice).
tx = np.concatenate([tx[:nth], x[nth:nth+1], tx[nth + 1:]], axis=0)
ty = np.concatenate([ty[:nth], y[nth:nth+1], ty[nth + 1:]], axis=0)
tz = np.concatenate([tz[:nth], z[nth:nth+1], tz[nth + 1:]], axis=0)
else:
# Fix the trajectories for the selected time step (here the last slice).
tx = tf.concat([tx[:nth], x[nth:nth+1], tx[nth + 1:]], axis=0)
ty = tf.concat([ty[:nth], y[nth:nth+1], ty[nth + 1:]], axis=0)
tz = tf.concat([tz[:nth], z[nth:nth+1], tz[nth + 1:]], axis=0)
return tx, ty, tz
# Note: this second definition overrides the one above, reducing
# transform_txyz2 to an identity map (the trajectories stay on the grid).
def transform_txyz2(tx, ty, tz, x, y, z, mod):
    return x, y, z
def compute_displacement_np(tx, ty, tz):
"""
Compute the displacement vector fields from particle trajectories in 3D using NumPy.
Parameters:
tx, ty, tz: np.ndarray
Particle trajectories, each with shape (nt, nx, ny, nz).
Returns:
ux, uy, uz: np.ndarray
Displacement fields in x, y, and z direction, each with shape (nt, nx, ny, nz).
"""
# Compute the initial positions in x, y, and z dimensions
initial_position_x = tx[0, :, :, :]
initial_position_y = ty[0, :, :, :]
initial_position_z = tz[0, :, :, :]
# Use broadcasting to compute displacements for all time steps in x, y, and z dimensions
ux = tx - initial_position_x[np.newaxis, ...]
uy = ty - initial_position_y[np.newaxis, ...]
uz = tz - initial_position_z[np.newaxis, ...]
return ux, uy, uz
def gradient_np(array, step, axis):
"""
Compute the gradient of a NumPy array using a central difference scheme.
"""
shifted_forward = np.roll(array, -1, axis=axis)
shifted_backward = np.roll(array, 1, axis=axis)
gradient = (shifted_forward - shifted_backward) / (2 * step)
return gradient
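# Illustrative sketch (not invoked by the pipeline): gradient_np on a linear
# ramp along axis 1. Interior values recover the slope exactly; since the
# scheme is based on np.roll it wraps around, so the two boundary slices are
# not meaningful.
def _demo_gradient_np():
    nt, nx, ny, nz = 1, 6, 4, 4
    ramp = 2.0 * np.arange(nx, dtype=np.float32).reshape(1, nx, 1, 1)
    field = np.broadcast_to(ramp, (nt, nx, ny, nz))
    g = gradient_np(field, step=1.0, axis=1)
    return g[:, 1:-1]  # every interior entry equals the slope 2.0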
def compute_strain_tensor_lagrangian_full_np(u_x, u_y, u_z, dx, dy, dz):
"""
Compute the full 3D Green-Lagrange strain tensor E using NumPy arrays.
Parameters
----------
u_x, u_y, u_z : np.ndarray
The components of the displacement field. Each should have shape (nt, nx, ny, nz).
dx, dy, dz : float
The step sizes in the x, y, and z directions.
Returns
-------
np.ndarray
The Green-Lagrange strain tensor. Shape: (3, 3, nt, nx, ny, nz).
"""
# Compute the spatial gradients of the displacement fields
u_x_x = gradient_np(u_x, dx, axis=1)
u_x_y = gradient_np(u_x, dy, axis=2)
u_x_z = gradient_np(u_x, dz, axis=3)
u_y_x = gradient_np(u_y, dx, axis=1)
u_y_y = gradient_np(u_y, dy, axis=2)
u_y_z = gradient_np(u_y, dz, axis=3)
u_z_x = gradient_np(u_z, dx, axis=1)
u_z_y = gradient_np(u_z, dy, axis=2)
u_z_z = gradient_np(u_z, dz, axis=3)
# Compute the Green-Lagrange strain tensor components
E_xx = 0.5 * (2 * u_x_x + u_x_x * u_x_x + u_y_x * u_y_x + u_z_x * u_z_x)
E_yy = 0.5 * (2 * u_y_y + u_x_y * u_x_y + u_y_y * u_y_y + u_z_y * u_z_y)
E_zz = 0.5 * (2 * u_z_z + u_x_z * u_x_z + u_y_z * u_y_z + u_z_z * u_z_z)
# Off-diagonals follow E_ij = 0.5 * (u_i,j + u_j,i + sum_k u_k,i * u_k,j)
E_xy = 0.5 * (u_x_y + u_y_x + u_x_x * u_x_y + u_y_x * u_y_y + u_z_x * u_z_y)
E_xz = 0.5 * (u_x_z + u_z_x + u_x_x * u_x_z + u_y_x * u_y_z + u_z_x * u_z_z)
E_yz = 0.5 * (u_y_z + u_z_y + u_x_y * u_x_z + u_y_y * u_y_z + u_z_y * u_z_z)
# Combine strain tensor components into a single 6D array
E = np.array([[E_xx, E_xy, E_xz], [E_xy, E_yy, E_yz], [E_xz, E_yz, E_zz]])
return E
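# Illustrative sketch (not invoked by the pipeline): a quick consistency check
# of the Green-Lagrange tensor for a uniform stretch u_x = a*x with u_y = u_z = 0,
# which should give E_xx = a + a**2 / 2 in the interior and zero elsewhere.
def _demo_strain_uniform_stretch():
    nt, nx, ny, nz = 1, 8, 8, 8
    a = 0.1
    x = np.arange(nx, dtype=np.float32).reshape(1, nx, 1, 1)
    u_x = np.ascontiguousarray(np.broadcast_to(a * x, (nt, nx, ny, nz)))
    u_y = np.zeros_like(u_x)
    u_z = np.zeros_like(u_x)
    E = compute_strain_tensor_lagrangian_full_np(u_x, u_y, u_z, 1.0, 1.0, 1.0)
    # E[0, 0] is E_xx; interior values are approximately 0.1 + 0.005 = 0.105.
    return E[0, 0][:, 1:-1]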
def compute_displacement(tx, ty, tz):
"""
Compute the displacement vector fields from particle trajectories in 3D.
Parameters:
tx, ty, tz: tf.Tensor
Particle trajectories, each with shape (nt, nx, ny, nz).
Returns:
ux, uy, uz: tf.Tensor
Displacement fields in x, y, and z direction, each with shape (nt, nx, ny, nz).
"""
# Compute the initial positions in x, y, and z dimensions
initial_position_x = tx[0, :, :, :]
initial_position_y = ty[0, :, :, :]
initial_position_z = tz[0, :, :, :]
# Expand dimensions of initial positions to match shape of tx, ty, tz
initial_position_x = tf.expand_dims(initial_position_x, axis=0)
initial_position_y = tf.expand_dims(initial_position_y, axis=0)
initial_position_z = tf.expand_dims(initial_position_z, axis=0)
# Use broadcasting to compute displacements for all time steps in x, y, and z dimensions
ux = tx - initial_position_x
uy = ty - initial_position_y
uz = tz - initial_position_z
return ux, uy, uz
def gradient(tensor, step, axis, final_op=False):
"""
Compute the gradient of a tensor using a central difference scheme in the interior and a first-order
scheme at the boundaries for 3D data.
Parameters
----------
tensor : tf.Tensor
The tensor to differentiate. Shape: (nt, nx, ny, nz).
step : float
The step size.
axis : int
The axis along which to compute the gradient.
final_op : bool
If this gradient operation is the final operation. Default is False.
Returns
-------
tf.Tensor
The gradient of the input tensor.
"""
tensor_before = tf.roll(tensor, shift=1, axis=axis)
tensor_after = tf.roll(tensor, shift=-1, axis=axis)
# Central difference in the interior of the domain
gradient = (tensor_after - tensor_before) / (2 * step)
if axis == 1:
# Forward difference at the left boundary and backward difference at the right boundary
gradient_left = (tensor[:, 1, :, :] - tensor[:, 0, :, :]) / step
gradient_right = (tensor[:, -1, :, :] - tensor[:, -2, :, :]) / step
gradient = tf.concat([gradient_left[:, None, :, :], gradient[:, 1:-1, :, :], gradient_right[:, None, :, :]], axis=axis)
elif axis == 2:
# Forward difference at the top boundary and backward difference at the bottom boundary
gradient_top = (tensor[:, :, 1, :] - tensor[:, :, 0, :]) / step
gradient_bottom = (tensor[:, :, -1, :] - tensor[:, :, -2, :]) / step
gradient = tf.concat([gradient_top[:, :, None, :], gradient[:, :, 1:-1, :], gradient_bottom[:, :, None, :]], axis=axis)
elif axis == 3:
# Forward difference at the front boundary and backward difference at the back boundary
gradient_front = (tensor[:, :, :, 1] - tensor[:, :, :, 0]) / step
gradient_back = (tensor[:, :, :, -1] - tensor[:, :, :, -2]) / step
gradient = tf.concat([gradient_front[:, :, :, None], gradient[:, :, :, 1:-1], gradient_back[:, :, :, None]], axis=axis)
return gradient
def dice_score_tf(mask1, mask2):
# Convert boolean masks to a dtype that TensorFlow operations can work with
mask1_float = tf.cast(mask1, dtype)
mask2_float = tf.cast(mask2, dtype)
# Compute the intersection
intersection = tf.reduce_sum(mask1_float * mask2_float)
# Compute Dice score
return 2. * intersection / (tf.reduce_sum(mask1_float) + tf.reduce_sum(mask2_float))
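# Illustrative sketch (not invoked by the pipeline): Dice score of two small
# boolean masks with |A| = 3, |B| = 2 and an overlap of 2, giving 2*2 / (3 + 2) = 0.8.
def _demo_dice_score():
    a = tf.constant([True, True, True, False])
    b = tf.constant([True, True, False, False])
    return dice_score_tf(a, b)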
def calculate_dice_scores(segm, coeff, c_euler_slice):
# Create masks from c_euler_slice
edema_mask_pred = (c_euler_slice > coeff[5]) & (c_euler_slice <= coeff[6])
core_mask_pred = c_euler_slice >= coeff[6]
# Convert these masks to TensorFlow tensors if necessary
edema_mask_pred_tf = tf.convert_to_tensor(edema_mask_pred, dtype=tf.bool)
core_mask_pred_tf = tf.convert_to_tensor(core_mask_pred, dtype=tf.bool)
# Calculate masks from your segmentation
edema_mask_true = get_edema_mask(segm)
core_mask_true = get_core_mask(segm)
# Calculate Dice scores
dice_score_edema = dice_score_tf(edema_mask_true, edema_mask_pred_tf)
dice_score_core = dice_score_tf(core_mask_true, core_mask_pred_tf)
return dice_score_edema, dice_score_core
def get_core_mask(segm,mod=tf):
# Use TensorFlow operations for logical or and equality checks
return tf.logical_or(tf.equal(segm, 1), tf.equal(segm, 4))
def get_edema_mask(segm,mod=tf):
# Use TensorFlow operation for equality check
return tf.equal(segm, 3)
def get_core_loss_tf(c, th_up, segm):
core_mask = get_core_mask(segm)
# Compute the core loss where core_mask is True, else set to 0
core_loss = tf.where(core_mask,
tf.clip_by_value(th_up - c, clip_value_min=0, clip_value_max=tf.float32.max),
tf.zeros_like(core_mask, dtype=tf.float32))
return core_loss
def get_edema_loss_tf(c, th_down, th_up, segm):
edema_mask = get_edema_mask(segm)
# Define the condition for values within the desired range
within_range_condition = tf.logical_and(c >= th_down, c <= th_up)
# Combine the mask with the condition
final_mask = tf.logical_and(edema_mask, tf.logical_not(within_range_condition))
# Compute the edema loss
edema_loss = tf.where(final_mask,
tf.abs(c - th_down) + tf.abs(c - th_up),
tf.zeros_like(final_mask, dtype=tf.float32))
return edema_loss
def get_outside_segm_mask(segm, mod=tf):
# Use TensorFlow operation for equality check
return tf.equal(segm, 0)
def get_outside_segm_loss_tf(c, th_down, segm):
outside_segm_mask = get_outside_segm_mask(segm)
# Define the condition where c[-1] is not below th_down
not_below_condition = c >= th_down
# Combine the mask with the condition
final_mask = tf.logical_and(outside_segm_mask, not_below_condition)
# Compute the outside segment loss
# If c[-1] is not below th_down in the outside_segm, penalize by the difference
outside_segm_loss = tf.where(final_mask,
c - th_down,
tf.zeros_like(final_mask, dtype=tf.float32))
return outside_segm_loss
def pet_loss(pet_data, segm_data, c_euler):
# Create a mask where segm_data is 1 or 3
mask = tf.logical_or(segm_data == 1, segm_data == 3)
# Apply the mask to flatten only the selected voxels
pet_data_masked = tf.boolean_mask(pet_data, mask)
c_euler_masked = tf.boolean_mask(c_euler, mask)
# Ensure both tensors are of the same data type
pet_data_masked = tf.cast(pet_data_masked, dtype=tf.float32)
c_euler_masked = tf.cast(c_euler_masked, dtype=tf.float32)
# Compute Pearson correlation on the selected voxels
def pearson_correlation(x, y):
mean_x = tf.reduce_mean(x)
mean_y = tf.reduce_mean(y)
normalized_x = x - mean_x
normalized_y = y - mean_y
covariance = tf.reduce_sum(normalized_x * normalized_y)
std_dev_x = tf.sqrt(tf.reduce_sum(tf.square(normalized_x)))
std_dev_y = tf.sqrt(tf.reduce_sum(tf.square(normalized_y)))
return covariance / (std_dev_x * std_dev_y)
if tf.size(pet_data_masked) == 0 or tf.size(c_euler_masked) == 0:
return 1.0 # If no valid voxels, return maximum loss
correlation = pearson_correlation(pet_data_masked, c_euler_masked)
return 1 - correlation # loss is 1 minus the correlation coefficient
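# Illustrative sketch (not invoked by the pipeline): pet_loss on a handful of
# toy voxels. Only voxels labelled 1 (core) or 3 (edema) enter the Pearson
# correlation; here the PET signal and the predicted cell density are perfectly
# linearly related on those voxels, so the loss is close to 0.
def _demo_pet_loss():
    segm = tf.constant([1.0, 3.0, 1.0, 0.0], dtype=tf.float32)
    pet = tf.constant([0.2, 0.5, 0.9, 0.0], dtype=tf.float32)
    c_pred = tf.constant([0.1, 0.4, 0.8, 0.0], dtype=tf.float32)
    return pet_loss(pet, segm, c_pred)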
def compute_strain_tensor_lagrangian_full(u_x, u_y, u_z, domain):
"""
Compute the full 3D Green-Lagrange strain tensor E given the x, y, and z components of the displacement field.
Parameters
----------
u_x : tf.Tensor
The x-component of the displacement field. Shape: (nt, nx, ny, nz).
u_y : tf.Tensor
The y-component of the displacement field. Shape: (nt, nx, ny, nz).
u_z : tf.Tensor
The z-component of the displacement field. Shape: (nt, nx, ny, nz).
Returns
-------
E : tf.Tensor
The Green-Lagrange strain tensor. Shape: (3, 3, nt, nx, ny, nz).
"""
dx = domain.step('x')
dy = domain.step('y')
dz = domain.step('z')
# Compute the spatial gradients of the displacement fields
u_x_x = gradient(u_x, dx, axis=1) # partial derivative of u_x w.r.t. x
u_x_y = gradient(u_x, dy, axis=2) # partial derivative of u_x w.r.t. y
u_x_z = gradient(u_x, dz, axis=3) # partial derivative of u_x w.r.t. z
u_y_x = gradient(u_y, dx, axis=1) # partial derivative of u_y w.r.t. x
u_y_y = gradient(u_y, dy, axis=2) # partial derivative of u_y w.r.t. y
u_y_z = gradient(u_y, dz, axis=3) # partial derivative of u_y w.r.t. z
u_z_x = gradient(u_z, dx, axis=1) # partial derivative of u_z w.r.t. x
u_z_y = gradient(u_z, dy, axis=2) # partial derivative of u_z w.r.t. y
u_z_z = gradient(u_z, dz, axis=3) # partial derivative of u_z w.r.t. z
# Compute the Green-Lagrange strain tensor components
E_xx = 0.5 * (2*u_x_x + u_x_x*u_x_x + u_y_x*u_y_x + u_z_x*u_z_x)
E_yy = 0.5 * (2*u_y_y + u_x_y*u_x_y + u_y_y*u_y_y + u_z_y*u_z_y)
E_zz = 0.5 * (2*u_z_z + u_x_z*u_x_z + u_y_z*u_y_z + u_z_z*u_z_z)
# Off-diagonals follow E_ij = 0.5 * (u_i,j + u_j,i + sum_k u_k,i * u_k,j)
E_xy = 0.5 * (u_x_y + u_y_x + u_x_x*u_x_y + u_y_x*u_y_y + u_z_x*u_z_y)
E_xz = 0.5 * (u_x_z + u_z_x + u_x_x*u_x_z + u_y_x*u_y_z + u_z_x*u_z_z)
E_yz = 0.5 * (u_y_z + u_z_y + u_x_y*u_x_z + u_y_y*u_y_z + u_z_y*u_z_z)
# Combine strain tensor components into a single 6D array
E = tf.stack([[E_xx, E_xy, E_xz], [E_xy, E_yy, E_yz], [E_xz, E_yz, E_zz]])
return E
def get_diffusion_coefficient(wm_intensity, gm_intensity, D_s, R):
return D_s * wm_intensity + (D_s / R) * gm_intensity
def m_Tildas(WM,GM,th):
WM_tilda_x = np.where(np.logical_and(np.roll(WM,-1,axis=0) + np.roll(GM,-1,axis=0) >= th,WM + GM >= th),(np.roll(WM,-1,axis=0) + WM)/2,0)
WM_tilda_y = np.where(np.logical_and(np.roll(WM,-1,axis=1) + np.roll(GM,-1,axis=1) >= th,WM + GM >= th),(np.roll(WM,-1,axis=1) + WM)/2,0)
WM_tilda_z = np.where(np.logical_and(np.roll(WM,-1,axis=2) + np.roll(GM,-1,axis=2) >= th,WM + GM >= th),(np.roll(WM,-1,axis=2) + WM)/2,0)
GM_tilda_x = np.where(np.logical_and(np.roll(WM,-1,axis=0) + np.roll(GM,-1,axis=0) >= th,WM + GM >= th),(np.roll(GM,-1,axis=0) + GM)/2,0)
GM_tilda_y = np.where(np.logical_and(np.roll(WM,-1,axis=1) + np.roll(GM,-1,axis=1) >= th,WM + GM >= th),(np.roll(GM,-1,axis=1) + GM)/2,0)
GM_tilda_z = np.where(np.logical_and(np.roll(WM,-1,axis=2) + np.roll(GM,-1,axis=2) >= th,WM + GM >= th),(np.roll(GM,-1,axis=2) + GM)/2,0)
return {"WM_t_x": WM_tilda_x,"WM_t_y": WM_tilda_y,"WM_t_z": WM_tilda_z,"GM_t_x": GM_tilda_x,"GM_t_y": GM_tilda_y,"GM_t_z": GM_tilda_z}
def get_D(WM, GM, th, Dw, Dw_ratio):
M = m_Tildas(WM,GM,th)
D_minus_x = Dw*(M["WM_t_x"] + M["GM_t_x"]/Dw_ratio)
D_minus_y = Dw*(M["WM_t_y"] + M["GM_t_y"]/Dw_ratio)
D_minus_z = Dw*(M["WM_t_z"] + M["GM_t_z"]/Dw_ratio)
D_plus_x = Dw*(np.roll(M["WM_t_x"],1,axis=0) + np.roll(M["GM_t_x"],1,axis=0)/Dw_ratio)
D_plus_y = Dw*(np.roll(M["WM_t_y"],1,axis=1) + np.roll(M["GM_t_y"],1,axis=1)/Dw_ratio)
D_plus_z = Dw*(np.roll(M["WM_t_z"],1,axis=2) + np.roll(M["GM_t_z"],1,axis=2)/Dw_ratio)
return {"D_minus_x": D_minus_x, "D_minus_y": D_minus_y, "D_minus_z": D_minus_z,"D_plus_x": D_plus_x, "D_plus_y": D_plus_y, "D_plus_z": D_plus_z}
def operator_adv(ctx):
global gamma, BC_w, pde_w, balance_w, neg_w, D_ch, R_ch, outside_skull_mask, CM_pos, pet_w
dt = ctx.step('t')
dx = ctx.step('x')
dy = ctx.step('y')
dz = ctx.step('z')
x = ctx.points('x')
y = ctx.points('y')
z = ctx.points('z')
nt = ctx.size('t')
nx = ctx.size('x')
ny = ctx.size('y')
nz = ctx.size('z')
def single_var(key, st=0, sx=0, sy=0, sz=0):
u = ctx.field(key, st, sx, sy, sz)
return u
def field_to_particles_3d(q_src, it):
q_src = ctx.cast(q_src)
# Pad the field for 3 dimensions
q_src = pad_linear(q_src, [(1, 1), (1, 1), (1, 1)])
# Initialize the tensor for the particle field
qp = tf.zeros(tix[it].shape, dtype=q_src.dtype)
# Loop through all combinations of x, y, z coordinates and weights
for jx, jy, jz, jw in [
(tix, tiy, tiz, sx0 * sy0 * sz0),
(tixp, tiy, tiz, sx1 * sy0 * sz0),
(tix, tiyp, tiz, sx0 * sy1 * sz0),
(tixp, tiyp, tiz, sx1 * sy1 * sz0),
(tix, tiy, tizp, sx0 * sy0 * sz1),
(tixp, tiy, tizp, sx1 * sy0 * sz1),
(tix, tiyp, tizp, sx0 * sy1 * sz1),
(tixp, tiyp, tizp, sx1 * sy1 * sz1),
]:
idx = tf.stack([jx[it] + 1, jy[it] + 1, jz[it] + 1], axis=-1)
qp += jw[it] * tf.gather_nd(q_src, idx)
return qp
def laplace(st):
"""
Calculate the Laplacian of a 4D field (time, x, y, z).
:param st: A tuple of field values (q, qxm, qxp, qym, qyp, qzm, qzp)
:return: Laplacian of the field
"""
q, qxm, qxp, qym, qyp, qzm, qzp = st
q_xx = (qxp - 2 * q + qxm) / dx**2
q_yy = (qyp - 2 * q + qym) / dy**2
q_zz = (qzp - 2 * q + qzm) / dz**2
q_lap = q_xx + q_yy + q_zz
return q_lap
def pad_linear(q, paddings):
"""
Apply linear padding to a 4D field.
:param q: The field to be padded
:param paddings: Padding specifications
:return: Padded field
"""
qr = tf.pad(q, paddings, mode='reflect')
qs = tf.pad(q, paddings, mode='symmetric')
q_padded = 2 * qs - qr
return q_padded
def depad(q, paddings):
"""
Remove padding from a 4D field.
:param q: The padded field
:param paddings: Padding specifications used for padding the field
:return: Field with padding removed
"""
pt, px, py, pz = paddings
slices = tuple(slice(p[0], -p[1] if p[1] else None) for p in paddings)
return q[slices]
def laplace_roll(q):
"""
Calculate the Laplacian of a 4D field using rolling operations.
:param q: The field for which the Laplacian is to be calculated
:return: Laplacian of the field
"""
paddings = [[0, 0], [1, 1], [1, 1], [1, 1]]
q_padded = pad_linear(q, paddings)
qxm = tf.roll(q_padded, shift=1, axis=1)
qxp = tf.roll(q_padded, shift=-1, axis=1)
qym = tf.roll(q_padded, shift=1, axis=2)
qyp = tf.roll(q_padded, shift=-1, axis=2)
qzm = tf.roll(q_padded, shift=1, axis=3)
qzp = tf.roll(q_padded, shift=-1, axis=3)
laplacian = laplace((q_padded, qxm, qxp, qym, qyp, qzm, qzp))
return depad(laplacian, paddings)
# Unknown parameters
coeff = ctx.field('coeff')
res = []
# Trajectories.
tx = single_var('x')
ty = single_var('y')
tz = single_var('z')
tx, ty, tz = transform_txyz(tx, ty, tz, x, y, z, tf)
# Tumor
c = single_var('c')
c = transform_c(c, mod=tf)
# Cell indices.
dtx = tx / dx - 0.5
dty = ty / dy - 0.5
dtz = tz / dz - 0.5
tix = tf.clip_by_value(ctx.cast(tf.floor(dtx), tf.int32), -1, nx - 1)
tiy = tf.clip_by_value(ctx.cast(tf.floor(dty), tf.int32), -1, ny - 1)
tiz = tf.clip_by_value(ctx.cast(tf.floor(dtz), tf.int32), -1, nz - 1)
tixp = tix + 1
tiyp = tiy + 1
tizp = tiz + 1
# Weights.
sx1 = tf.clip_by_value(dtx - ctx.cast(tix), 0, 1)
sy1 = tf.clip_by_value(dty - ctx.cast(tiy), 0, 1)
sz1 = tf.clip_by_value(dtz - ctx.cast(tiz), 0, 1)
sx0 = 1 - sx1
sy0 = 1 - sy1
sz0 = 1 - sz1
# Unpack the unknown model parameters
D_scalar = coeff[0]
rho = coeff[1]
x0 = coeff[2]
y0 = coeff[3]
z0 = coeff[4]
gamma = tf.constant(1.0,dtype=dtype)*coeff[7]
# Clip th_down to [0.20, 0.35] and th_up to [0.50, 0.85].
th_down = tf.clip_by_value(coeff[5], clip_value_min=0.20, clip_value_max=0.35)
th_up = tf.clip_by_value(coeff[6], clip_value_min=0.50, clip_value_max=0.85)
# Get the spatially-varying diffusion coefficient based on wm_intensities
D = get_D(wm_data, gm_data, matter_th, D_scalar, R_ch)
# Calculate the tumor pde loss
pde_loss = tumor_pde_loss(tx, ty, tz, dt, c, D,rho)
res += [pde_loss*pde_w]
#BC
c_init = gauss_sol3d_tf((x[0,:]-x0),(y[0,:]-y0),(z[0,:]-z0),dx,dy,dz,init_scale=0.8)
#c_init_lagrange = field_to_particles_3d(c_init,1)
c_init_lagrange = c_init #better for the init tissues
bc = c[1] - c_init_lagrange
res += [bc*BC_w]
bc0 = c[0]
res += [bc0*BC_w]
# Calculate the strain tensors and balance them
ux, uy, uz = compute_displacement(tx, ty, tz)
E = compute_strain_tensor_lagrangian_full(ux, uy, uz, ctx)
lambda_vals = compute_lambda(wm_data, gm_data, csf_data, c)
lambda_vals = lambda_vals / tf.reduce_max(lambda_vals)
mu_vals = compute_mu(wm_data, gm_data, csf_data, c)
mu_vals = mu_vals / tf.reduce_max(mu_vals)
#calculate the balance residual
res_balance = compute_strain_balance_tf(E,c,gamma,lambda_vals,mu_vals,ctx)
res += [res_balance*balance_w]
#static tissues:
#res += [[tx - x,ty - y,tz - z]]*w_static_tissues
#negative tumor
res += [(c - tf.abs(c))*neg_w]
#outside of matter mask
combined_matter = wm_data + gm_data
combined_matter_mask = (combined_matter < matter_th).astype(int)
# Repeat the combined_matter_mask across the time dimension
combined_matter_mask_4d = np.repeat(combined_matter_mask[np.newaxis, :, :], nt, axis=0)
res += [c * combined_matter_mask_4d * outside_w]
#Data fit
c_euler = particles_to_field_3d_average(c[-1],tx[-1],ty[-1],tz[-1],ctx.domain)
res += [get_outside_segm_loss_tf(c_euler,th_down,segm_data)*outside_w]
res += [get_edema_loss_tf(c_euler,th_down,th_up,segm_data)*edema_w]
res += [get_core_loss_tf(c_euler,th_up,segm_data)*core_w]
res += [tf.clip_by_value(0.11 - D_scalar ,0,100)*params_w]
res += [tf.clip_by_value(0.02 - rho ,0,100)*params_w]
res += [tf.clip_by_value(1.5 + gamma ,0,100)*params_w]
# Smoothness of particles in space.
ltx = laplace_roll(tx) * kxreg
lty = laplace_roll(ty) * kxreg
ltz = laplace_roll(tz) * kxreg
res += [ltx, lty, ltz]
# Smoothness of particles in time.
txm = tf.roll(tx, shift=[1], axis=[0])
tym = tf.roll(ty, shift=[1], axis=[0])
tzm = tf.roll(tz, shift=[1], axis=[0])
txp = tf.roll(tx, shift=[-1], axis=[0])
typ = tf.roll(ty, shift=[-1], axis=[0])
tzp = tf.roll(tz, shift=[-1], axis=[0])
# Calculate residuals for x, y, and z dimensions
res += [
((txp - 2 * tx + txm) / dt**2)[1:-1] * ktreg,
((typ - 2 * ty + tym) / dt**2)[1:-1] * ktreg,
((tzp - 2 * tz + tzm) / dt**2)[1:-1] * ktreg,
]
#brain looks symmetric at the beginning
healthy_wm = field_to_particles_3d(wm_data,nt-1)
for factor in [4, 8]:
res += [calculate_symmetry_loss(healthy_wm, scale_factor=factor)*symmetry_w]
# Process gray matter (gm)
healthy_gm = field_to_particles_3d(gm_data,nt-1)
for factor in [4, 8]:
res += [calculate_symmetry_loss(healthy_gm, scale_factor=factor)*symmetry_w]
# Process cerebrospinal fluid (csf)
healthy_csf = field_to_particles_3d(csf_data,nt-1)
for factor in [4, 8]:
res += [calculate_symmetry_loss(healthy_csf, scale_factor=factor)*symmetry_w]
# Normalize the losses assembled so far before adding the PET term
kappa = 3.0
res = [r / kappa for r in res]
# PET loss
res += [pet_loss(pet_data, segm_data, c_euler)*pet_w]
return res
def calculate_symmetry_loss(healthy, scale_factor=1):
# Assuming healthy shape is [depth, height, width]
depth, height, width = healthy.shape
# Downsample the tensor if scale_factor is greater than 1
if scale_factor > 1:
# Define the pooling size and strides for 3D data (batch, depth, height, width, channels)
pool_size = [1, 1, scale_factor, scale_factor, 1] # [batch, depth, height, width, channels]
strides = [1, 1, scale_factor, scale_factor, 1]
# Reshape the tensor to 5D for pooling (batch_size, depth, height, width, channels)
healthy_reshaped = tf.reshape(healthy, [1, depth, height, width, 1])
# Perform average pooling
healthy_downsampled = tf.nn.avg_pool3d(input=healthy_reshaped, ksize=pool_size, strides=strides, padding='VALID')
new_depth, new_height, new_width = healthy_downsampled.shape[1:4]
healthy_downsampled = tf.reshape(healthy_downsampled, [new_depth, new_height, new_width])
else:
healthy_downsampled = healthy
new_depth, new_height, new_width = healthy.shape
# Update dimensions after downsampling
_, height_downsampled, _ = healthy_downsampled.shape
# Split the tensor into upper and lower halves along the Y-axis (height)
mid = height_downsampled // 2
upper_half = healthy_downsampled[:, :mid, :]
lower_half = healthy_downsampled[:, mid:height_downsampled, :] if height_downsampled % 2 == 0 else healthy_downsampled[:, mid+1:, :]
# Mirror the lower half for comparison
lower_half_mirrored = tf.reverse(lower_half, axis=[1]) # Use axis=1 for the Y-axis (height)
# Calculate the absolute difference between the mirrored lower half and the upper half
difference = tf.abs(upper_half - lower_half_mirrored)
# Compute the loss as the mean of the differences
loss = tf.reduce_mean(difference)
return loss
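# Illustrative sketch (not invoked by the pipeline): a volume built to be
# mirror-symmetric about the mid-plane of its second (height) axis should give
# a symmetry loss of zero, while a generic volume does not. With scale_factor
# greater than 1 the comparison is done on an average-pooled version instead.
def _demo_symmetry_loss():
    half = tf.reshape(tf.range(8 * 8 * 16, dtype=tf.float32), (8, 8, 16))
    symmetric = tf.concat([half, tf.reverse(half, axis=[1])], axis=1)
    return calculate_symmetry_loss(symmetric, scale_factor=1)  # -> 0.0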
def normalize_intensities(wm_intensity, gm_intensity, csf_intensity, c):
total = wm_intensity + gm_intensity + csf_intensity + c
# Check if the total is below the threshold
below_threshold = total < matter_th
# Avoid division by zero
total = tf.where(below_threshold, tf.constant(1, dtype=dtype), total)
equal_proportion = tf.constant(0.25, dtype=dtype)
# Normalize intensities
normalized_wm = tf.where(below_threshold, equal_proportion, wm_intensity / total)
normalized_gm = tf.where(below_threshold, equal_proportion, gm_intensity / total)
normalized_csf = tf.where(below_threshold, equal_proportion, csf_intensity / total)
normalized_c = tf.where(below_threshold, equal_proportion, c / total)
return normalized_wm, normalized_gm, normalized_csf, normalized_c
def compute_lambda(wm_intensity, gm_intensity, csf_intensity, c):
# Normalize the intensities
wm_intensity, gm_intensity, csf_intensity, c = normalize_intensities(wm_intensity, gm_intensity, csf_intensity, c)
# Constants
E = [2100, 2100, 100, 8000] # Young's modulus for WM, GM, CSF, tumor
nu = [0.4, 0.4, 0.1, 0.45] # Poisson's ratio for WM, GM, CSF, tumor
# Calculating lambda for each material
lambda_vals = [E[i] * nu[i] / ((1 + nu[i]) * (1 - 2 * nu[i])) for i in range(4)]
# Compute weighted lambda based on normalized intensities
weighted_lambda = (wm_intensity * lambda_vals[0] +
gm_intensity * lambda_vals[1] +
csf_intensity * lambda_vals[2] +
c * lambda_vals[3])
return weighted_lambda
def compute_mu(wm_intensity, gm_intensity, csf_intensity, c):
# Normalize the intensities
wm_intensity, gm_intensity, csf_intensity, c = normalize_intensities(wm_intensity, gm_intensity, csf_intensity, c)
# Constants
E = [2100, 2100, 100, 8000] # Young's modulus for WM, GM, CSF, tumor
nu = [0.4, 0.4, 0.1, 0.45] # Poisson's ratio for WM, GM, CSF, tumor
# Calculating mu for each material
mu_vals = [E[i] / (2 * (1 + nu[i])) for i in range(4)]
# Compute weighted mu based on normalized intensities
weighted_mu = (wm_intensity * mu_vals[0] +
gm_intensity * mu_vals[1] +
csf_intensity * mu_vals[2] +
c * mu_vals[3])
return weighted_mu
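# Illustrative sketch (not invoked by the pipeline): a worked check of the Lamé
# conversion used above. For the tumor entry (E = 8000, nu = 0.45):
#   lambda = E*nu / ((1 + nu) * (1 - 2*nu)) = 3600 / 0.145 ≈ 24828
#   mu     = E / (2 * (1 + nu))             = 8000 / 2.9   ≈ 2759
# so a voxel that is pure tumor (c = 1, zero tissue intensities) maps to roughly
# these values before the normalization applied in operator_adv.
def _demo_lame_parameters():
    one = tf.constant(1.0, dtype=dtype)
    zero = tf.constant(0.0, dtype=dtype)
    return compute_lambda(zero, zero, zero, one), compute_mu(zero, zero, zero, one)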
def compute_strain_balance_tf(E, c, gamma, lambda_vals, mu_vals, domain):
residuals = [] # A list to hold the residuals for each timestep
dx = domain.step('x')
dy = domain.step('y')
dz = domain.step('z')
for t in range(0, c.shape[0]):
lambda_ = lambda_vals[t]
mu = mu_vals[t]
grad_c_x, grad_c_y, grad_c_z = grad_fd(c[t], dx, dy, dz)
# Stack gradient to match the dimensionality of div
tumor_force = gamma * tf.stack([grad_c_x, grad_c_y, grad_c_z], axis=0)
# Calculate the total stress
sigma_tissue = compute_stress(E[...,t,:,:,:], lambda_, mu)
div = divergence_fd(sigma_tissue,dx,dy,dz)
# Calculate the residual
residual = div + tumor_force
residuals.append(residual)
return tf.stack(residuals) # Stack the residuals to form a tensor
def divergence_fd(tensor, dx, dy, dz):
"""
Compute the divergence of a vector field in 3D using the gradient function.
Parameters
----------
tensor : tf.Tensor
The vector field tensor with shape (3,3, nx, ny, nz).
dx : float, optional
The step size in the x-direction.
dy : float, optional
The step size in the y-direction.
dz : float, optional
The step size in the z-direction.
Returns
-------
tf.Tensor
The divergence of the input tensor.
"""
# Shift tensor in x, y, and z directions
tensor_right = tf.roll(tensor, shift=-1, axis=-3)
tensor_up = tf.roll(tensor, shift=-1, axis=-2)
tensor_forward = tf.roll(tensor, shift=-1, axis=-1)
# Shift tensor in x, y, and z directions
tensor_left = tf.roll(tensor, shift=1, axis=-3)
tensor_down = tf.roll(tensor, shift=1, axis=-2)
tensor_back = tf.roll(tensor, shift=1, axis=-1)
# Calculate the derivatives in the x, y, and z directions
div_x = (tensor_right - tensor_left) / (2*dx)
div_y = (tensor_up - tensor_down) / (2*dy)
div_z = (tensor_forward - tensor_back) / (2*dz)
# Sum the partial derivatives to get the divergence
div = div_x[...,0,:,:,:] + div_y[...,1,:,:,:] + div_z[...,2,:,:,:]
return div
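# Illustrative sketch (not invoked by the pipeline): divergence_fd applied to a
# stress field whose only nonzero entry is sigma_xx = x. Away from the periodic
# wrap-around at the x boundaries, the first component of the divergence equals
# d(sigma_xx)/dx = 1.
def _demo_divergence_fd():
    nx = ny = nz = 6
    x = np.arange(nx, dtype=np.float32).reshape(nx, 1, 1)
    sigma = np.zeros((3, 3, nx, ny, nz), dtype=np.float32)
    sigma[0, 0] = np.broadcast_to(x, (nx, ny, nz))
    div = divergence_fd(tf.constant(sigma), dx=1.0, dy=1.0, dz=1.0)
    return div[0, 1:-1]  # interior values along x are 1.0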
def gradient_3d(tensor, step, axis):
"""
Compute the gradient of a 3D tensor using a central difference scheme in the interior and a first-order
scheme at the boundaries.
Parameters
----------
tensor : tf.Tensor
The tensor to differentiate. Shape: (nx, ny, nz).
step : float
The step size.
axis : int
The axis along which to compute the gradient.
Returns
-------
tf.Tensor
The gradient of the input tensor.
"""
tensor_before = tf.roll(tensor, shift=1, axis=axis)
tensor_after = tf.roll(tensor, shift=-1, axis=axis)