Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions neurite/py/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

# local (our) imports


def get_backend():
"""
Returns the currently used backend. Default is tensorflow unless the
Expand All @@ -24,23 +23,23 @@ def softmax(x, axis):
"""
softmax of a numpy array along a given dimension
"""

return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)
# Numerically stable softmax computation
max_x = np.amax(x, axis=axis, keepdims=True)
exp_x = np.exp(x - max_x)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def rebase_lab(labels):
"""
Rebase labels and return lookup table (LUT) to convert to new labels in
interval [0, N[ as: LUT[label_map]. Be sure to pass all possible labels.
"""
labels = np.unique(labels) # Sorted.
labels, counts = np.unique(labels, return_counts=True)
assert np.issubdtype(labels.dtype, np.integer), 'non-integer data'

lab_to_ind = np.zeros(np.max(labels) + 1, dtype='int_')
for i, lab in enumerate(labels):
lab_to_ind[lab] = i
lab_to_ind[labels] = np.cumsum(counts)
lab_to_ind[0] = 0
ind_to_lab = labels

return lab_to_ind, ind_to_lab


Expand Down Expand Up @@ -85,13 +84,11 @@ def seg_to_rgb_fs_lut(seg, label_table):
Returns:
ndarray: RGB (3-frame) image with shape of input seg.
"""
unique = np.unique(seg)
color_seg = np.zeros((*seg.shape, 3), dtype='uint8')
for sid in unique:
label = label_table.get(sid)
if label is not None:
color_seg[seg == sid] = label['color']
return color_seg
unique, inv_idx = np.unique(seg, return_inverse=True)
has_color = np.array([sid in label_table for sid in unique])
color_table = np.zeros((len(unique), 3), dtype=np.uint8)
color_table[has_color] = [label_table[sid]['color'] for sid in unique[has_color]]
return color_table[inv_idx].reshape((*seg.shape, 3))


def fs_lut_to_cmap(lut):
Expand All @@ -113,9 +110,15 @@ def fs_lut_to_cmap(lut):
"""
if isinstance(lut, str):
lut = load_fs_lut(lut)

keys = list(lut.keys())
rgb = np.zeros((np.array(keys).max() + 1, 3), dtype='float')
has_color = np.zeros((np.array(keys).max() + 1,), dtype=np.bool)
for key in keys:
rgb[key] = lut[key]['color']
return matplotlib.colors.ListedColormap(rgb / 255)
has_color[key] = 'color' in lut[key]
rgb[key] = lut[key]['color'] if has_color[key] else 0
cmap = matplotlib.colors.ListedColormap(rgb / 255)
for i, key in enumerate(keys):
if has_color[key]:
cmap.set_bad(color=rgb[key] / 255, alpha=0)
cmap.set_over(color=rgb[key] / 255, alpha=0)
return cmap