@@ -1,54 +1,62 @@
from dataset_iterator import TrackingIterator
from dataset_iterator.tile_utils import extract_tile_random_zoom_function
import numpy as np
import numpy.ma as ma
-from scipy.ndimage import center_of_mass, find_objects, maximum_filter, map_coordinates
+from scipy.ndimage import center_of_mass, find_objects, maximum_filter, map_coordinates, gaussian_filter, convolve
from skimage.transform import rescale
from skimage.feature import peak_local_max
import skfmm
from math import copysign, isnan
import sys
import itertools
import edt
from random import random
from .medoid import get_medoid
+from ..utils import image_derivatives_np as der
+import time
+
+CENTER_MODE = ["GEOMETRICAL", "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"]
+
class DyDxIterator(TrackingIterator):
def __init__(self,
dataset,
extract_tile_function, # = extract_tile_random_zoom_function(tile_shape=(128, 128), n_tiles=8, zoom_range=[0.6, 1.6], aspect_ratio_range=[0.75, 1.5], random_channel_jitter_shape=[50, 50] ),
frame_window:int,
aug_frame_subsampling, # either int: frame number interval will be drawn uniformly in [frame_window,aug_frame_subsampling] or callable that generate an frame number interval (int)
erase_edge_cell_size:int,
next:bool = True,
allow_frame_subsampling_direct_neigh:bool = False,
- aug_remove_prob: float = 0.01,
+ aug_remove_prob: float = 0.005,
return_link_multiplicity:bool = True,
channel_keywords:list=['/raw', '/regionLabels'], # channel @1 must be label
array_keywords:list=['/linksPrev'],
elasticdeform_parameters:dict = None,
downscale_displacement_and_link_multiplicity=1,
- return_center = True,
- center_mode = "MEDOID", # GEOMETRICAL, "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"
+ return_edm_derivatives: bool = False,
+ return_center:bool = True,
+ center_mode:str = "MEDOID", # GEOMETRICAL, "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"
return_label_rank = False,
long_term:bool = True,
return_next_displacement:bool = True,
output_float16=False,
**kwargs):
assert len(channel_keywords)==2, 'keyword should contain 2 elements in this order: grayscale input images, object labels'
assert len(array_keywords) == 1, 'array keyword should contain 1 element: links to previous objects'
+ assert center_mode.upper() in CENTER_MODE, f"invalid center mode: {center_mode} should be in {CENTER_MODE}"
self.return_link_multiplicity=return_link_multiplicity
self.downscale=downscale_displacement_and_link_multiplicity
self.erase_edge_cell_size=erase_edge_cell_size
self.aug_frame_subsampling=aug_frame_subsampling
self.allow_frame_subsampling_direct_neigh=allow_frame_subsampling_direct_neigh
self.output_float16=output_float16
+ self.return_edm_derivatives=return_edm_derivatives
self.return_center=return_center
- self.center_mode=center_mode
+ self.center_mode=center_mode.upper()
self.return_label_rank=return_label_rank
assert frame_window>=1, "frame_window must be >=1"
self.frame_window = frame_window
self.return_next_displacement=return_next_displacement
self.n_label_max = kwargs.pop("n_label_max", 2000)
self.long_term=long_term if self.frame_window>1 else False
self.return_central_only = False
@@ -213,30 +221,34 @@
self._erase_small_objects_at_edges(labelIms[i,...,frame_window], i, mask_to_erase_cur, mask_to_erase_chan_cur, batch_by_channel)
# prev timepoints
for j in range(0, frame_window):
self._erase_small_objects_at_edges(labelIms[i,...,j], i, mask_to_erase_prev, [m+j for m in mask_to_erase_chan_prev], batch_by_channel)
if return_next:
for j in range(0, frame_window):
self._erase_small_objects_at_edges(labelIms[i,...,frame_window+1+j], i, mask_to_erase_next, [m+j for m in mask_to_erase_chan_next], batch_by_channel)
+ object_slices = {}
+ for b, c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
+ object_slices[(b, c)] = find_objects(labelIms[b,...,c])
edm = np.zeros(shape=labelIms.shape, dtype=np.float32)
for b,c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
- edm[b,...,c] = edt.edt(labelIms[b,...,c], black_border=False)
+ edm[b,...,c] = edt_smooth(labelIms[b,...,c], object_slices[(b, c)])
+ #edm[b,...,c] = edt.edt(labelIms[b,...,c], black_border=False)
n_motion = 2 * frame_window if return_next else frame_window
if long_term:
n_motion = n_motion + (2 * ( frame_window - 1 ) if return_next else frame_window -1)
dyIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
dxIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
if ndisp:
dyImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
dxImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
if self.return_link_multiplicity:
linkMultiplicityImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
centerIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.return_center else None
if self.return_label_rank:
- labelIm = np.zeros(labelIms.shape, dtype=np.int32)
+ rankIm = np.zeros(labelIms.shape, dtype=np.int32)
prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32)
nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32)
centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32)
centerArr.fill(np.nan)
if self.return_link_multiplicity:
linkMultiplicityIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
@@ -251,54 +263,67 @@
for b,c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
labels_and_centers[(b, c)] = _get_labels_and_centers(labelIms[b][...,c], edm[b][...,c], self.center_mode)
for i in range(labelIms.shape[0]):
bidx = get_idx(i)
for c in range(0, frame_window):
sel = [c, c+1]
l_c = [labels_and_centers[(i,s)] for s in sel]
- _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=labelIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
+ o_s = [object_slices[(i, s)] for s in sel]
+ _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
if return_next:
for c in range(frame_window, 2*frame_window):
sel = [c, c+1]
l_c = [labels_and_centers[(i, s)] for s in sel]
- _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
+ o_s = [object_slices[(i, s)] for s in sel]
+ _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
if long_term:
off = 2*frame_window if return_next else frame_window
for c in range(0, frame_window-1):
sel = [c, frame_window]
l_c = [labels_and_centers[(i, s)] for s in sel]
- _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
+ o_s = [object_slices[(i, s)] for s in sel]
+ _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], o_s, dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
if return_next:
for c in range(frame_window-1, 2*(frame_window-1)):
sel = [frame_window, c+3]
l_c = [labels_and_centers[(i, s)] for s in sel]
- _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
- other_output_channels = [chan_idx for chan_idx in self.output_channels if chan_idx!=1 and chan_idx!=2]
- all_channels = [batch_by_channel[chan_idx] for chan_idx in other_output_channels]
- channel_inc = 0
+ o_s = [object_slices[(i, s)] for s in sel]
+ _compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], o_s, dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
+
+ edm[edm == 0] = -1
+ if self.return_edm_derivatives:
+ der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
+ for b, c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
+ derivatives_labelwise(edm[b, ..., c], -1, der_y[b, ..., c], der_x[b, ..., c], labelIms[b, ..., c], object_slices[(b, c)])
+ if self.return_central_only:
+ der_y = der_y[..., 1:2]
+ der_x = der_x[..., 1:2]
if self.return_central_only: # select only central frame for edm / center and only displacement / link multiplicity related to central frame
edm = edm[..., 1:2]
centerIm = centerIm[..., 1:2]
dyIm = dyIm[..., :1]
dxIm = dxIm[..., :1]
if self.return_link_multiplicity:
linkMultiplicityIm = linkMultiplicityIm[..., :1]
if ndisp:
dyImNext = dyImNext[..., 1:]
dxImNext = dxImNext[..., 1:]
if self.return_link_multiplicity:
linkMultiplicityImNext = linkMultiplicityImNext[..., 1:]
if self.return_label_rank:
- labelIm = labelIm[..., 1:2]
+ rankIm = rankIm[..., 1:2]
centerArr = centerArr[: , 1:2]
prevLabelArr = prevLabelArr[:, :1]
if ndisp:
nextLabelArr = nextLabelArr[:, 1:]
-
- edm[edm==0] = -1
+ if self.return_edm_derivatives:
+ edm = np.concatenate([edm, der_y, der_x], -1)
+ other_output_channels = [chan_idx for chan_idx in self.output_channels if chan_idx != 1 and chan_idx != 2]
+ all_channels = [batch_by_channel[chan_idx] for chan_idx in other_output_channels]
+ channel_inc = 0
all_channels.insert(channel_inc, edm)
if self.return_center:
channel_inc+=1
all_channels.insert(channel_inc, centerIm)
downscale_factor = 1./self.downscale if self.downscale>1 else 0
scale = [1, downscale_factor, downscale_factor, 1]
if self.downscale>1:
@@ -325,15 +350,15 @@
if self.output_float16:
for i, c in enumerate(all_channels):
if not (self.return_link_multiplicity and i == 3 + channel_inc or self.return_label_rank and (i == channel_inc or i == channel_inc + 1)): # softmax / sigmoid activation -> float32
all_channels[i] = c.astype('float16', copy=False)
if self.return_label_rank:
if ndisp:
prevLabelArr = np.concatenate([prevLabelArr, nextLabelArr], 1)
- all_channels.append(labelIm)
+ all_channels.append(rankIm)
all_channels.append(prevLabelArr)
all_channels.append(centerArr)
return all_channels
def _erase_small_objects_at_edges(self, labelImage, batch_idx, channel_idxs, channel_idxs_chan, batch_by_channel):
objects_to_erase = _get_small_objects_at_edges_to_erase(labelImage, self.erase_edge_cell_size)
if len(objects_to_erase)>0:
@@ -367,39 +392,25 @@
return dict()
if center_mode == "GEOMETRICAL":
centers = center_of_mass(labelIm, labelIm, labels)
elif center_mode == "EDM_MAX":
assert edm is not None and edm.shape == labelIm.shape
centers = []
for label in labels:
- edm_label = ma.array(edm, mask = labelIm != label)
+ edm_label = ma.array(edm, mask=labelIm != label)
center = ma.argmax(edm_label, fill_value=0)
center = np.unravel_index(center, edm_label.shape)
centers.append(center)
elif center_mode == "EDM_MEAN":
assert edm is not None and edm.shape == labelIm.shape
centers = center_of_mass(edm, labelIm, labels)
elif center_mode == "SKELETON":
- assert edm is not None and edm.shape == labelIm.shape
- mass_centers = np.array(center_of_mass(labelIm, labelIm, labels))[np.newaxis] # 1, N_ob, 2
- lm_coords = peak_local_max(edm, labels = labelIm) # N_lm, 2
- lm_coords_l = labelIm[lm_coords[:,0], lm_coords[:,1]] # N_lm
- # labels in labelIm are not necessarily continuous -> replace by rank
- label_rank = np.zeros(shape=(max(labels)+1,), dtype=np.int32)
- for l in labels:
- label_rank[l] = labels.index(l)
- lm_coords_l = label_rank[lm_coords_l]
- lm_coords_l = np.eye(len(labels))[lm_coords_l] # N_lm, N_ob
- lm_coords_ob = lm_coords[:,np.newaxis] * lm_coords_l[...,np.newaxis] # N_lm, N_ob, 1 # TODO medoid of local maxima, weighted by edm?
- lm_coords_dist = np.sum(np.square(mass_centers - lm_coords_ob), 2, keepdims=True) # N_lm, N_ob, 1
- lm_coords_dinv= 1./(lm_coords_dist + 0.1) # N_lm, N_ob, 1
- lm_coords_dinv = lm_coords_dinv * lm_coords_ob # erase weights that are outside object
- wsum=np.sum(lm_coords_ob * lm_coords_dinv, 0, keepdims=False) # N_ob, 2
- sum=np.sum(lm_coords_dinv, 0, keepdims=False) # N_ob, 1
- centers = wsum / sum # N_ob, 2
+ edm_lap = der.laplacian_2d(gaussian_filter(edm, sigma=1.5))
+ skeleton = edm_lap<0.25
+ centers = [get_medoid(*np.asarray( (labelIm == l) & skeleton).nonzero()) for l in labels]
elif center_mode == "MEDOID":
centers = [get_medoid(*np.asarray(labelIm == l).nonzero()) for l in labels]
else:
raise ValueError(f"Invalid center mode: {center_mode}")
return dict(zip(labels, centers))
# channel dimension = frames
@@ -476,15 +487,15 @@
if n_neigh == 0:
return 3
elif n_neigh == 1:
return 1
else:
return 2
-def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, dyIm, dxIm, dyImNext=None, dxImNext=None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_mode:str= "MEDOID"):
+def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, object_slices, dyIm, dxIm, dyImNext=None, dxImNext=None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_mode:str= "MEDOID"):
assert labelIm.shape[-1] == 2, f"invalid labelIm : {labelIm.shape[-1]} channels instead of 2"
assert (dxImNext is None) == (dyImNext is None)
if len(labels_map_centers[-1])==0: # no cells
return
curLabelIm = labelIm[...,-1]
labels_prev = labels_map_centers[0].keys()
labels_prev_rank = {l:r for r, l in enumerate(labels_prev)}
@@ -522,46 +533,85 @@
nextLabelArr[rank] = labels_next_rank[label_next]+1
if linkMultiplicityImNext is not None:
linkMultiplicityImNext[mask] = _get_category(len(label_nexts))
if rankImPrev is not None:
rankImPrev[mask] = rank + 1
if gdcmIm is not None:
assert gdcmIm.shape == dyIm.shape, "invalid shape for center image"
- _draw_centers(gdcmIm, labels_map_centers[-1], labelIm[...,1], geometrical_distance=center_mode == "GEOMETRICAL")
+ _draw_centers(gdcmIm, labels_map_centers[-1], labelIm[...,1], object_slices[1], geometrical_distance=center_mode == "GEOMETRICAL")
if gdcmImPrev is not None:
assert gdcmImPrev.shape == dyIm.shape, "invalid shape for center image prev"
- _draw_centers(gdcmImPrev, labels_map_centers[0], labelIm[...,0], geometrical_distance=center_mode == "GEOMETRICAL")
+ _draw_centers(gdcmImPrev, labels_map_centers[0], labelIm[...,0], object_slices[0], geometrical_distance=center_mode == "GEOMETRICAL")
if centerArr is not None:
for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
centerArr[rank,0] = center[0]
centerArr[rank,1] = center[1]
if centerArrPrev is not None:
for rank, (label, center) in enumerate(labels_map_centers[0].items()):
centerArrPrev[rank,0] = center[0]
centerArrPrev[rank,1] = center[1]
-def _draw_centers(centerIm, labels_map_centers, labelIm, geometrical_distance:bool=False):
+def _draw_centers(centerIm, labels_map_centers, labelIm, object_slices, geometrical_distance:bool=False):
if len(labels_map_centers)==0:
return
# geodesic distance to center
if not geometrical_distance:
- count = 0
- m = np.ones_like(labelIm)
- for center in labels_map_centers.values():
- if not (isnan(center[0]) or isnan(center[1])):
- m[int(round(center[0])), int(round(center[1]))] = 0
- count+=1
- if count>0:
- m = ma.masked_array(m, ~labelIm.astype(bool))
- centerIm[:] = skfmm.distance(m)
+ shape = centerIm.shape
+ labelIm_dil = maximum_filter(labelIm, size=5)
+ non_zero = labelIm>0
+ labelIm_dil[non_zero] = labelIm[non_zero]
+ for (i, sl) in enumerate(object_slices):
+ if sl is not None:
+            center = labels_map_centers.get(i+1)
+            if center is not None and not (isnan(center[0]) or isnan(center[1])):
+ sl = tuple( [slice(max(0, s.start - 2), min(s.stop + 2, ax), s.step) for s, ax in zip(sl, shape)])
+ mask = labelIm_dil[sl] == i + 1
+ m = np.ones_like(mask)
+ #print(f"label: {i+1} slice: {sl}, center: {center}, sub_m {sub_m.shape}, coord: {(int(round(center[0]))-sl[0].start, int(round(center[1]))-sl[1].start)}", flush=True)
+ m[int(round(center[0]))-sl[0].start, int(round(center[1]))-sl[1].start] = 0
+ m = ma.masked_array(m, ~mask)
+ centerIm[sl][mask] = skfmm.distance(m)[mask]
else:
Y, X = centerIm.shape
Y, X = np.meshgrid(np.arange(Y, dtype = np.float32), np.arange(X, dtype = np.float32), indexing = 'ij')
for i, (label, center) in enumerate(labels_map_centers.items()): # in case center prediction is a classification
if isnan(center[0]) or isnan(center[1]):
pass
else:
# distance to center
mask = labelIm==label
if mask.sum()>0:
d = np.sqrt(np.square(Y-center[0])+np.square(X-center[1]))
centerIm[mask] = d[mask]
+
+
+def edt_smooth(labelIm, object_slices):
+ shape = labelIm.shape
+ upsampled = np.kron(labelIm, np.ones((2, 2))) # upsample by factor 2
+ w=np.ones(shape=(3, 3), dtype=np.int8)
+ for (i, sl) in enumerate(object_slices):
+ if sl is not None:
+ sl = tuple([slice(max(s.start*2 - 1, 0), min(s.stop*2 + 1, ax*2), s.step) for s, ax in zip(sl, shape)])
+ sub_labelIm = upsampled[sl]
+ mask = sub_labelIm == i + 1
+ new_mask = convolve(mask.astype(np.int8), weights=w, mode="nearest") > 4 # smooth borders
+ sub_labelIm[mask] = 0 # replace mask by smoothed
+ sub_labelIm[new_mask] = i + 1
+ edm = edt.edt(upsampled)
+ edm = edm.reshape((shape[0], 2, shape[1], 2)).mean(-1).mean(1) # downsample (bin) by factor 2
+ edm = np.divide(edm + 0.5, 2) #convert to pixel unit
+ edm[edm <= 0.5] = 0
+ return edm
+
+
+def derivatives_labelwise(image, bck_value, der_y, der_x, labelIm, object_slices):
+ shape = labelIm.shape
+ for (i, sl) in enumerate(object_slices):
+ if sl is not None:
+ sl = tuple([slice(max(s.start - 1, 0), min(s.stop + 1, ax), s.step) for s, ax in zip(sl, shape)])
+ mask = labelIm[sl] == i + 1
+ sub_im = np.copy(image[sl])
+ sub_im[np.logical_not(mask)] = bck_value # erase neighboring cells
+ sub_der_y, sub_der_x = der.der_2d(sub_im, 0, 1)
+ der_y[sl][mask] = sub_der_y[mask]
+ der_x[sl][mask] = sub_der_x[mask]
+ return der_y, der_x
@@ -1,49 +1,53 @@
import tensorflow as tf
from .layers import ker_size_to_string, Combine, ResConv2D, Conv2DBNDrop, Conv2DTransposeBNDrop, WSConv2D, BatchToChannel, SplitBatch, ChannelToBatch, NConvToBatch2D, SelectFeature
import numpy as np
from .spatial_attention import SpatialAttention2D
from ..utils.helpers import ensure_multiplicity, flatten_list
-from ..utils.losses import weighted_loss_by_category, balanced_category_loss, PseudoHuber
+from ..utils.losses import weighted_loss_by_category, balanced_category_loss, PseudoHuber, compute_loss_derivatives
from ..utils.agc import adaptive_clip_grad
from .gradient_accumulator import GradientAccumulator
+
import time
class DiSTNetModel(tf.keras.Model):
def __init__(self, *args, spatial_dims,
edm_loss_weight:float=1,
center_loss_weight:float=1,
displacement_loss_weight:float=1,
category_loss_weight:float=1,
- edm_loss=PseudoHuber(1),
- gcdm_loss=PseudoHuber(1), gcdm_gradients:bool=False,
+ edm_loss=PseudoHuber(1), edm_derivative_loss:bool=False,
+ gcdm_loss=PseudoHuber(1), gcdm_derivative_loss:bool=False,
displacement_loss=PseudoHuber(1),
category_weights=None, # array of weights: [normal, division, no previous cell] or None = auto
category_class_frequency_range=[1/50, 50],
next=True,
frame_window=3,
long_term:bool=True,
predict_next_displacement:bool=True,
+ predict_gcdm_derivatives:bool=False, predict_edm_derivatives:bool=False,
print_gradients:bool=False, # for optimization, available in eager mode only
accum_steps=1, use_agc=False, agc_clip_factor=0.1, agc_eps=1e-3, agc_exclude_output=False, # lower clip factor clips more
**kwargs):
super().__init__(*args, **kwargs)
self.edm_weight = edm_loss_weight
self.center_weight = center_loss_weight
self.displacement_weight = displacement_loss_weight
self.category_weight = category_loss_weight
self.spatial_dims = spatial_dims
self.next = next
self.predict_next_displacement=predict_next_displacement
self.frame_window = frame_window
self.edm_loss = edm_loss
- self.gdcm_loss=gcdm_loss
- self.gcdm_gradients=gcdm_gradients
+ self.edm_derivative_loss = edm_derivative_loss
+ self.gcdm_loss = gcdm_loss
+ self.gcdm_derivative_loss = gcdm_derivative_loss
self.displacement_loss = displacement_loss
-
+ self.predict_gcdm_derivatives = predict_gcdm_derivatives
+ self.predict_edm_derivatives = predict_edm_derivatives
min_class_frequency=category_class_frequency_range[0]
max_class_frequency=category_class_frequency_range[1]
if category_weights is not None:
assert len(category_weights)==3, "3 category weights should be provided: normal cell, dividing cell, cell with no previous cell"
self.category_loss=weighted_loss_by_category(tf.keras.losses.CategoricalCrossentropy(), category_weights, remove_background=True)
else:
self.category_loss = balanced_category_loss(tf.keras.losses.CategoricalCrossentropy(),3, min_class_frequency=min_class_frequency, max_class_frequency=max_class_frequency, remove_background=True)
@@ -52,15 +56,15 @@
self.use_grad_acc = accum_steps>1
self.accum_steps = float(accum_steps)
if self.use_grad_acc or use_agc:
self.gradient_accumulator = GradientAccumulator(accum_steps, self)
self.use_agc = use_agc
self.agc_clip_factor = agc_clip_factor
self.agc_eps = agc_eps
- self.agc_exclude_keywords=["DecoderTrackY0/, DecoderTrackX0/", "DecoderCat0/", "DecoderCenter0/", "DecoderSegEDM0"] if agc_exclude_output else None
+        self.agc_exclude_keywords=["DecoderTrackY0/", "DecoderTrackX0/", "DecoderCat0/", "DecoderCenterGCDM0/", "DecoderCenterGCDMdY0/", "DecoderCenterGCDMdX0/", "DecoderSegEDM0/", "DecoderSegEDMdY0/", "DecoderSegEDMdX0/"] if agc_exclude_output else None
self.print_gradients=print_gradients
def train_step(self, data):
if self.use_grad_acc:
self.gradient_accumulator.init_train_step()
fw = self.frame_window
@@ -74,39 +78,51 @@
displacement_weight = self.displacement_weight / (2. * n_fp_mul) # y & x
category_weight = self.category_weight / (n_fp_mul * float(n_frame_pairs))
edm_weight = self.edm_weight / float(n_frames) # divide by channel number ?
center_weight = self.center_weight / float(n_frames)# divide by channel number ?
if category_weight>0:
assert len(y) == 5 , f"invalid number of output. Expected: >=4 actual {len(y)}" # 0 = edm, 1 = center, 2 = dY, 3 = dX, 4 = cat
else:
- assert len(y) >= 4, f"invalid number of output. Expected: >=4 actual {len(y)}" # 0 = edm, 1 = center, 2 = dY, 3 = dX, 4 = cat
+ assert len(y) >= 4, f"invalid number of output. Expected: >=4 actual {len(y)}" # 0 = edm, 1 = center, 2 = dY, 3 = dX
with tf.GradientTape(persistent=self.print_gradients) as tape:
y_pred = self(x, training=True) # Forward pass
+ if self.predict_edm_derivatives:
+ edm, edm_dy, edm_dx = tf.split(y_pred[0], num_or_size_splits=3, axis=-1)
+ else:
+ edm, edm_dy, edm_dx = y_pred[0], None, None
+ if self.predict_gcdm_derivatives:
+ gcdm, gcdm_dy, gcdm_dx = tf.split(y_pred[1], num_or_size_splits=3, axis=-1)
+ else:
+ gcdm, gcdm_dy, gcdm_dx = y_pred[1], None, None
+ if self.predict_edm_derivatives or self.edm_derivative_loss:
+ true_edm, true_edm_dy, true_edm_dx = tf.split(y[0], num_or_size_splits=3, axis=-1)
+ else:
+ true_edm, true_edm_dy, true_edm_dx = y[0], None, None
# compute loss
losses = dict()
loss_weights = dict()
- cell_mask = tf.math.greater(y[0], 0)
+ cell_mask = tf.math.greater(true_edm, 0.5)
+ cell_mask_interior = tf.math.greater(true_edm, 1.5) if self.gcdm_derivative_loss or self.predict_gcdm_derivatives else None
# edm
if edm_weight>0:
- edm_loss = self.edm_loss(y[0], y_pred[0]) #, sample_weight = weight_map)
+ edm_loss = compute_loss_derivatives(true_edm, edm, self.edm_loss, true_dy=true_edm_dy, true_dx=true_edm_dx, pred_dy=edm_dy, pred_dx=edm_dx, derivative_loss=self.edm_derivative_loss, laplacian_loss=self.edm_derivative_loss)
losses["edm"] = edm_loss
loss_weights["edm"] = edm_weight
# center
if center_weight>0:
- center_pred_inside=tf.where(cell_mask, y_pred[1], 0) # do not predict anything outside
- center_loss = self.gdcm_loss(y[1], center_pred_inside)
+ center_loss = compute_loss_derivatives(y[1], gcdm, self.gcdm_loss, pred_dy=gcdm_dy, pred_dx=gcdm_dx, mask=cell_mask, mask_interior=cell_mask_interior, derivative_loss=self.gcdm_derivative_loss)
losses["center"] = center_loss
loss_weights["center"] = center_weight
#regression displacement loss
if displacement_weight>0:
- loss_dY, loss_dX = self._compute_displacement_loss(y, y_pred, cell_mask, spatial_gradients=self.gcdm_gradients)
+ loss_dY, loss_dX = self._compute_displacement_loss(y, y_pred, cell_mask)
losses["dY"] = loss_dY
loss_weights["dY"] = displacement_weight
losses["dX"] = loss_dX
loss_weights["dX"] = displacement_weight
# category loss
if category_weight>0:
@@ -121,15 +137,15 @@
# print(f"reg loss: {len(self.losses)} values: {self.losses}")
if len(self.losses)>0:
loss += tf.add_n(self.losses) # regularizers
losses["loss"] = loss
if self.print_gradients:
- trainable_vars_tape = [t for t in self.trainable_variables if (t.name.startswith("DecoderSegEDM") or t.name.startswith("DecoderCenter0") or t.name.startswith("DecoderTrackY0") or t.name.startswith("DecoderTrackX0") or t.name.startswith("DecoderCat0") or t.name.startswith("FeatureSequence/Op4") or t.name.startswith("Attention")) and ("/kernel" in t.name or "/wv" in t.name) ]
+ trainable_vars_tape = [t for t in self.trainable_variables if (t.name.startswith("DecoderSegEDM") or t.name.startswith("DecoderCenterGCDM") or t.name.startswith("DecoderTrackY0") or t.name.startswith("DecoderTrackX0") or t.name.startswith("DecoderCat0") or t.name.startswith("FeatureSequence/Op4") or t.name.startswith("Attention")) and ("/kernel" in t.name or "/wv" in t.name) ]
for loss_name, loss_value in losses.items():
if loss_name != "loss" :
w = loss_weights[loss_name] # outside tape: cannot modify loss_value -> need to apply w to gradient itself
if mixed_precision:
loss_value = self.optimizer.get_scaled_loss(loss_value)
gradients = tape.gradient(loss_value, trainable_vars_tape)
if mixed_precision:
@@ -153,45 +169,42 @@
if not self.use_grad_acc:
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) #Update weights
else:
self.gradient_accumulator.accumulate_gradients(gradients)
self.gradient_accumulator.apply_gradients()
return losses
- def _compute_displacement_loss(self, y, y_pred, cell_mask, spatial_gradients:bool=False):
+ def _compute_displacement_loss(self, y, y_pred, cell_mask):
+ mask = self._to_pair_mask(cell_mask)
+ dy = tf.where(mask, y_pred[2], 0) # do not predict anything outside
+ dx = tf.where(mask, y_pred[3], 0) # do not predict anything outside
+ return self.displacement_loss(y[2], dy), self.displacement_loss(y[3], dx)
+
+ def _to_pair_mask(self, cell_mask):
fw = self.frame_window
mask = cell_mask[..., 1:]
if self.predict_next_displacement:
mask_next = cell_mask[..., :-1]
if self.long_term and fw > 1:
mask_center = tf.tile(mask[..., fw - 1:fw], [1, 1, 1, fw - 1])
if self.predict_next_displacement:
if self.next:
- mask = tf.concat( [mask, mask_center, cell_mask[..., -fw + 1:], mask_next, cell_mask[..., :fw - 1], mask_center], -1)
+ mask = tf.concat(
+ [mask, mask_center, cell_mask[..., -fw + 1:], mask_next, cell_mask[..., :fw - 1], mask_center],
+ -1)
else:
mask = tf.concat([mask, mask_center, mask_next, cell_mask[..., :fw - 1]], -1)
else:
if self.next:
mask = tf.concat([mask, mask_center, cell_mask[..., -fw + 1:]], -1)
else:
mask = tf.concat([mask, mask_center], -1)
elif self.predict_next_displacement:
mask = tf.concat([mask, mask_next], -1)
-
- dy_inside = tf.where(mask, y_pred[2], 0) # do not predict anything outside
- dx_inside = tf.where(mask, y_pred[3], 0) # do not predict anything outside
- if spatial_gradients:
- dy_dy, dx_dy = tf.image.image_gradients(y[2])
- dy_dx, dx_dx = tf.image.image_gradients(y[3])
- dy_dy_pred, dx_dy_pred = tf.image.image_gradients(y[2])
- dy_dy_pred, dx_dy_pred = tf.where(mask, dy_dy_pred, 0), tf.where(mask, dx_dy_pred, 0)
- dy_dx_pred, dx_dx_pred = tf.image.image_gradients(y[3])
- dy_dx_pred, dx_dx_pred = tf.where(mask, dy_dx_pred, 0), tf.where(mask, dx_dx_pred, 0)
- return 1./3 * (self.displacement_loss(y[2], dy_inside) + self.displacement_loss(dy_dy, dy_dy_pred) + self.displacement_loss(dx_dy, dx_dy_pred)), 1./3 * (self.displacement_loss(y[3], dx_inside) + self.displacement_loss(dy_dx, dy_dx_pred) + self.displacement_loss(dx_dx, dx_dx_pred))
- return self.displacement_loss(y[2], dy_inside), self.displacement_loss(y[3], dx_inside)
+ return mask
def _compute_category_loss(self, y, y_pred, n_frame_pairs, n_fp_mul):
cat_loss = 0.
for i in range(n_frame_pairs * n_fp_mul):
inside_mask = tf.math.greater(y[4][..., i:i + 1], 0)
cat_pred_inside = tf.where(inside_mask, y_pred[4][..., 3 * i:3 * i + 3], 1)
cat_loss = cat_loss + self.category_loss(y[4][..., i:i + 1], cat_pred_inside)
@@ -209,22 +222,23 @@
self.compile()
super().save(*args, **kwargs)
if inference:
self.set_inference(False)
self.trainable=True
self.compile()
+
def get_distnet_2d(input_shape,
frame_window:int,
next:bool,
config,
name: str="DiSTNet2D",
**kwargs):
- return get_distnet_2d_model(input_shape, upsampling_mode = config.upsampling_mode, downsampling_mode=config.downsampling_mode, skip_stop_gradient=False, skip_connections=config.skip_connections, encoder_settings=config.encoder_settings, feature_settings=config.feature_settings, feature_blending_settings=config.feature_blending_settings, decoder_settings=config.decoder_settings, feature_decoder_settings=config.feature_decoder_settings, attention=config.attention, attention_dropout=config.dropout, self_attention=config.self_attention, combine_kernel_size=config.combine_kernel_size, pair_combine_kernel_size=config.pair_combine_kernel_size, blending_filter_factor=config.blending_filter_factor, frame_window=frame_window, next=next, name=name, **kwargs)
+ return get_distnet_2d_model(input_shape, upsampling_mode=config.upsampling_mode, downsampling_mode=config.downsampling_mode, skip_stop_gradient=False, skip_connections=config.skip_connections, encoder_settings=config.encoder_settings, feature_settings=config.feature_settings, feature_blending_settings=config.feature_blending_settings, decoder_settings=config.decoder_settings, feature_decoder_settings=config.feature_decoder_settings, attention=config.attention, attention_dropout=config.dropout, self_attention=config.self_attention, combine_kernel_size=config.combine_kernel_size, pair_combine_kernel_size=config.pair_combine_kernel_size, blending_filter_factor=config.blending_filter_factor, frame_window=frame_window, next=next, name=name, **kwargs)
def get_distnet_2d_model(input_shape, # Y, X
encoder_settings:list,
feature_settings: list,
feature_blending_settings: list,
feature_decoder_settings:list,
decoder_settings: list,
@@ -239,15 +253,16 @@
attention : int = 0,
attention_dropout:float = 0.1,
self_attention: int = 0,
frame_window:int = 1,
next:bool=True,
long_term:bool = True,
predict_next_displacement:bool = True,
- gcdm_gradients:bool=False,
+ predict_edm_derivatives:bool = False,
+ predict_gcdm_derivatives:bool = False,
l2_reg:float = 0,
name: str="DiSTNet2D",
**kwargs,
):
total_contraction = np.prod([np.prod([params.get("downscale", 1) for params in param_list]) for param_list in encoder_settings])
assert len(encoder_settings)==len(decoder_settings), "decoder should have same length as encoder"
if attention>0 or self_attention>0:
@@ -298,39 +313,47 @@
feature_pair_skip_op = Combine(filters=feature_filters, l2_reg=l2_reg, name="FeaturePairSkip")
# define decoder operations
decoder_layers={"Seg":[], "Center":[], "Track":[], "Cat":[]}
get_seq_and_filters = lambda l : [l[i] for i in [0, 3]]
decoder_feature_op={n: get_seq_and_filters(parse_param_list(feature_decoder_settings, f"Features{n}", l2_reg=l2_reg, last_input_filters=feature_filters)) for n in decoder_layers.keys()}
decoder_out={"Seg":{}, "Center":{}, "Track":{}, "Cat":{}}
- output_per_decoder = {"Seg": ["EDM"], "Center": ["Center"], "Track": ["dY", "dX"] if not predict_next_displacement else ["dY", "dX", "dYNext", "dXNext"], "Cat": ["Cat"] if not predict_next_displacement else ["Cat", "CatNext"]}
+ output_per_decoder = {"Seg": ["EDM"], "Center": ["GCDM"], "Track": ["dY", "dX"] if not predict_next_displacement else ["dY", "dX", "dYNext", "dXNext"], "Cat": ["Cat"] if not predict_next_displacement else ["Cat", "CatNext"]}
+ if predict_edm_derivatives:
+ output_per_decoder["Seg"].append("EDMdY")
+ output_per_decoder["Seg"].append("EDMdX")
+ if predict_gcdm_derivatives:
+ output_per_decoder["Center"].append("GCDMdY")
+ output_per_decoder["Center"].append("GCDMdX")
n_frame_pairs = n_chan -1
if long_term:
n_frame_pairs = n_frame_pairs + (frame_window-1) * (2 if next else 1)
decoder_is_segmentation = {"Seg": True, "Center": True, "Track": False, "Cat": False}
skip_per_decoder = {"Seg": skip_connections, "Center": [], "Track": [], "Cat": []}
for l_idx, param_list in enumerate(decoder_settings):
if l_idx==0:
- decoder_out["Seg"]["EDM"] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="linear", filters_out=1, l2_reg=l2_reg, layer_idx=l_idx, name=f"DecoderSegEDM")
- decoder_out["Center"]["Center"] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="linear", filters_out=1, l2_reg=l2_reg, layer_idx=l_idx, name=f"DecoderCenter")
+ for dSegName in output_per_decoder["Seg"]:
+ decoder_out["Seg"][dSegName] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="linear", filters_out=1, l2_reg=l2_reg, layer_idx=l_idx, name=f"DecoderSeg{dSegName}")
+ for dCenterName in output_per_decoder["Center"]:
+ decoder_out["Center"][dCenterName] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="linear", filters_out=1, l2_reg=l2_reg, layer_idx=l_idx, name=f"DecoderCenter{dCenterName}")
for dTrackName in output_per_decoder["Track"]:
decoder_out["Track"][dTrackName] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="linear", filters_out=1, l2_reg=l2_reg, layer_idx=l_idx, name=f"DecoderTrack{dTrackName}")
for dCatName in output_per_decoder["Cat"]:
decoder_out["Cat"][dCatName] = decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation_out="softmax", filters_out=3, l2_reg=l2_reg, layer_idx=l_idx, name=f"Decoder{dCatName}")
else:
for decoder_name, d_layers in decoder_layers.items():
d_layers.append( decoder_op(**param_list, size_factor=contraction_per_layer[l_idx], mode=upsampling_mode, skip_combine_mode=skip_combine_mode, combine_kernel_size=combine_kernel_size, activation="relu", l2_reg=l2_reg, layer_idx=l_idx, name=f"Decoder{decoder_name}") )
decoder_output_names = dict()
oi = 0
for n, o_ns in output_per_decoder.items():
decoder_output_names[n] = dict()
for o_n in o_ns:
- decoder_output_names[n][o_n] = f"Output{oi}_{o_n}"
+ decoder_output_names[n][o_n] = f"Output{oi:02}_{o_n}"
oi += 1
# Create GRAPH
input = tf.keras.layers.Input(shape=spatial_dims+[n_chan], name="Input")
print(f"input dims: {input.shape}")
input_merged = ChannelToBatch(compensate_gradient = False, name = "MergeInputs")(input)
downsampled = [input_merged]
residuals = []
@@ -397,26 +420,33 @@
for i, (l, res) in enumerate(zip(d_layers[::-1], residuals[:-1])):
up = l([up, res if len(d_layers)-1-i in skip else None])
output_per_dec = dict()
for output_name in output_per_decoder[decoder_name]:
if output_name in decoder_out[decoder_name]:
d_out = decoder_out[decoder_name][output_name]
layer_output_name = decoder_output_names[decoder_name][output_name]
- if not is_segmentation and predict_next_displacement:
+            if (not is_segmentation and predict_next_displacement) or (decoder_name == "Seg" and predict_edm_derivatives) or (decoder_name == "Center" and predict_gcdm_derivatives):
layer_output_name += "_" # will be concatenated -> output name is used @ concat
up_out = d_out([up, residuals[-1] if 0 in skip else None]) # (N_OUT x B, Y, X, F)
up_out = BatchToChannel(n_splits = n_chan if is_segmentation else n_frame_pairs, compensate_gradient = False, name = layer_output_name)(up_out)
output_per_dec[output_name] = up_out
if predict_next_displacement: # merge outputs ending by Next
for k in list(output_per_dec.keys()):
if k.endswith("Next"):
output_name = k.replace("Next", "")
output_per_dec[output_name] = tf.keras.layers.Concatenate(axis = -1, name = decoder_output_names[decoder_name][output_name])([output_per_dec[output_name], output_per_dec.pop(k)])
+ if decoder_name=="Seg" and predict_edm_derivatives:
+ output_name = "EDM"
+ output_per_dec[output_name] = tf.keras.layers.Concatenate(axis=-1, name=decoder_output_names[decoder_name][output_name])([output_per_dec[output_name], output_per_dec.pop("EDMdY"), output_per_dec.pop("EDMdX")])
+ if decoder_name=="Center" and predict_gcdm_derivatives:
+ output_name = "GCDM"
+ output_per_dec[output_name] = tf.keras.layers.Concatenate(axis=-1, name=decoder_output_names[decoder_name][output_name])([output_per_dec[output_name], output_per_dec.pop("GCDMdY"), output_per_dec.pop("GCDMdX")])
outputs.extend(output_per_dec.values())
- return DiSTNetModel([input], outputs, name=name, frame_window=frame_window, next = next, spatial_dims=spatial_dims if attention > 0 or self_attention > 0 else None, long_term=long_term, predict_next_displacement=predict_next_displacement, gcdm_gradients=gcdm_gradients, **kwargs)
+ return DiSTNetModel([input], outputs, name=name, frame_window=frame_window, next=next, spatial_dims=spatial_dims if attention > 0 or self_attention > 0 else None, long_term=long_term, predict_next_displacement=predict_next_displacement, predict_gcdm_derivatives=predict_gcdm_derivatives, predict_edm_derivatives=predict_edm_derivatives, **kwargs)
+
def encoder_op(param_list, downsampling_mode, skip_stop_gradient:bool = False, l2_reg:float=0, last_input_filters:int=0, name: str="EncoderLayer", layer_idx:int=1):
name=f"{name}{layer_idx}"
maxpool = downsampling_mode=="maxpool"
maxpool_and_stride = downsampling_mode == "maxpool_and_stride"
sequence, down_sequence, total_contraction, residual_filters, out_filters = parse_param_list(param_list, name, ignore_stride=maxpool, l2_reg=l2_reg, last_input_filters = last_input_filters if maxpool_and_stride else 0)
assert total_contraction>1, "invalid parameters: no contraction specified"