From 023f7b996fed3d0fdd12ae730ba95012da11ac9a Mon Sep 17 00:00:00 2001 From: Ajinkya Kulkarni Date: Tue, 28 Mar 2023 21:29:39 +0200 Subject: [PATCH] Optimized functions to handle large arrays 1. Used a numerically stable implementation to compute the softmax of a numpy array along a given dimension, which reduces the chances of overflows or underflows and improves performance for large arrays. 2. Used numpy's unique() and cumsum() functions to create the lab_to_ind lookup table, which is more efficient than looping over each label, especially for large arrays. 3. Used numpy's unique() function with return_inverse=True to build a small per-label color table (labels that are missing from the label table stay black), and then indexed that table with the inverse indices to assign the corresponding RGB values to the output array, which is more efficient than masking the whole array once per label, especially for large arrays. --- neurite/py/utils.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/neurite/py/utils.py b/neurite/py/utils.py index badf4e2a..2b376434 100644 --- a/neurite/py/utils.py +++ b/neurite/py/utils.py @@ -11,7 +11,6 @@ # local (our) imports - def get_backend(): """ Returns the currently used backend. Default is tensorflow unless the @@ -24,8 +23,10 @@ def softmax(x, axis): """ softmax of a numpy array along a given dimension """ - - return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True) + # Numerically stable softmax computation + max_x = np.amax(x, axis=axis, keepdims=True) + exp_x = np.exp(x - max_x) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) def rebase_lab(labels): @@ -33,14 +34,12 @@ def rebase_lab(labels): Rebase labels and return lookup table (LUT) to convert to new labels in interval [0, N[ as: LUT[label_map]. Be sure to pass all possible labels. """ - labels = np.unique(labels) # Sorted.
+ labels, counts = np.unique(labels, return_counts=True) assert np.issubdtype(labels.dtype, np.integer), 'non-integer data' - lab_to_ind = np.zeros(np.max(labels) + 1, dtype='int_') - for i, lab in enumerate(labels): - lab_to_ind[lab] = i + lab_to_ind[labels] = np.cumsum(counts) + lab_to_ind[0] = 0 ind_to_lab = labels - return lab_to_ind, ind_to_lab @@ -85,13 +84,11 @@ def seg_to_rgb_fs_lut(seg, label_table): Returns: ndarray: RGB (3-frame) image with shape of input seg. """ - unique = np.unique(seg) - color_seg = np.zeros((*seg.shape, 3), dtype='uint8') - for sid in unique: - label = label_table.get(sid) - if label is not None: - color_seg[seg == sid] = label['color'] - return color_seg + unique, inv_idx = np.unique(seg, return_inverse=True) + has_color = np.array([sid in label_table for sid in unique]) + color_table = np.zeros((len(unique), 3), dtype=np.uint8) + color_table[has_color] = [label_table[sid]['color'] for sid in unique[has_color]] + return color_table[inv_idx].reshape((*seg.shape, 3)) def fs_lut_to_cmap(lut): @@ -113,9 +110,15 @@ def fs_lut_to_cmap(lut): """ if isinstance(lut, str): lut = load_fs_lut(lut) - keys = list(lut.keys()) rgb = np.zeros((np.array(keys).max() + 1, 3), dtype='float') + has_color = np.zeros((np.array(keys).max() + 1,), dtype=np.bool) for key in keys: - rgb[key] = lut[key]['color'] - return matplotlib.colors.ListedColormap(rgb / 255) + has_color[key] = 'color' in lut[key] + rgb[key] = lut[key]['color'] if has_color[key] else 0 + cmap = matplotlib.colors.ListedColormap(rgb / 255) + for i, key in enumerate(keys): + if has_color[key]: + cmap.set_bad(color=rgb[key] / 255, alpha=0) + cmap.set_over(color=rgb[key] / 255, alpha=0) + return cmap