diff --git a/neurite/py/utils.py b/neurite/py/utils.py index badf4e2a..2b376434 100644 --- a/neurite/py/utils.py +++ b/neurite/py/utils.py @@ -11,7 +11,6 @@ # local (our) imports - def get_backend(): """ Returns the currently used backend. Default is tensorflow unless the @@ -24,8 +23,10 @@ def softmax(x, axis): """ softmax of a numpy array along a given dimension """ - - return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True) + # Numerically stable softmax computation + max_x = np.amax(x, axis=axis, keepdims=True) + exp_x = np.exp(x - max_x) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) def rebase_lab(labels): @@ -33,14 +34,12 @@ def rebase_lab(labels): Rebase labels and return lookup table (LUT) to convert to new labels in interval [0, N[ as: LUT[label_map]. Be sure to pass all possible labels. """ - labels = np.unique(labels) # Sorted. + labels, counts = np.unique(labels, return_counts=True) assert np.issubdtype(labels.dtype, np.integer), 'non-integer data' - lab_to_ind = np.zeros(np.max(labels) + 1, dtype='int_') - for i, lab in enumerate(labels): - lab_to_ind[lab] = i + lab_to_ind[labels] = np.cumsum(counts) + lab_to_ind[0] = 0 ind_to_lab = labels - return lab_to_ind, ind_to_lab @@ -85,13 +84,11 @@ def seg_to_rgb_fs_lut(seg, label_table): Returns: ndarray: RGB (3-frame) image with shape of input seg. """ - unique = np.unique(seg) - color_seg = np.zeros((*seg.shape, 3), dtype='uint8') - for sid in unique: - label = label_table.get(sid) - if label is not None: - color_seg[seg == sid] = label['color'] - return color_seg + unique, inv_idx = np.unique(seg, return_inverse=True) + has_color = np.array([sid in label_table for sid in unique]) + color_table = np.zeros((len(unique), 3), dtype=np.uint8) + color_table[has_color] = [label_table[sid]['color'] for sid in unique[has_color]] + return color_table[inv_idx].reshape((*seg.shape, 3)) def fs_lut_to_cmap(lut): @@ -113,9 +110,15 @@ def fs_lut_to_cmap(lut): """ if isinstance(lut, str): lut = load_fs_lut(lut) - keys = list(lut.keys()) rgb = np.zeros((np.array(keys).max() + 1, 3), dtype='float') + has_color = np.zeros((np.array(keys).max() + 1,), dtype=np.bool) for key in keys: - rgb[key] = lut[key]['color'] - return matplotlib.colors.ListedColormap(rgb / 255) + has_color[key] = 'color' in lut[key] + rgb[key] = lut[key]['color'] if has_color[key] else 0 + cmap = matplotlib.colors.ListedColormap(rgb / 255) + for i, key in enumerate(keys): + if has_color[key]: + cmap.set_bad(color=rgb[key] / 255, alpha=0) + cmap.set_over(color=rgb[key] / 255, alpha=0) + return cmap