Skip to content

Commit 7772f9b

Browse files
committed
bugfix and speedup
1 parent 417565c commit 7772f9b

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,16 @@ def show(
666666
# If any of the previous conditions are not met, generate random
667667
# colors for each cell id
668668

669+
N_DISTINCT_FOR_RANDOM = 30
670+
669671
if sdata.table is not None:
670672
# annoying case since number of cells in labels can be
671673
# different from number of cells in table. So we just use
672674
# the index and randomise colours for it
673675

676+
# add fake column for limiting the amount of different colors
677+
sdata.table.obs["fake"] = np.random.randint(0, N_DISTINCT_FOR_RANDOM, sdata.table.obs.shape[0])
678+
674679
# has a table, so it has a region key
675680
region_key = _get_region_key(sdata)
676681

@@ -681,7 +686,7 @@ def show(
681686
region_key = _get_region_key(sdata)
682687
instance_key = _get_instance_key(sdata)
683688
params["instance_key"] = instance_key
684-
params["color_key"] = instance_key
689+
params["color_key"] = "fake"
685690
params["add_legend"] = False
686691
# TODO(ttreis) log the decision not to display a legend
687692

@@ -693,7 +698,7 @@ def show(
693698
cell_ids_per_label = {}
694699
for key in list(sdata.labels.keys()):
695700
cell_ids_per_label[key] = sdata.labels[key].values.max()
696-
701+
print(cell_ids_per_label)
697702
region_key = "tmp_label_id"
698703
instance_key = "tmp_cell_id"
699704
params["instance_key"] = instance_key
@@ -708,10 +713,11 @@ def show(
708713
}
709714
)
710715

716+
tmp_table["fake"] = np.random.randint(0, N_DISTINCT_FOR_RANDOM, sdata.table.obs.shape[0])
711717
distinct_cells = max(list(cell_ids_per_label.values()))
712718

713719
if sdata.table is not None:
714-
print("Plotting a lot of cells with random colors, might take a while...")
720+
# print("Plotting a lot of cells with random colors, might take a while...")
715721
sdata.table.uns[f"{instance_key}_colors"] = _get_random_hex_colors(distinct_cells)
716722

717723
elif sdata.table is None:

src/spatialdata_plot/pl/render.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def _render_labels(
161161
ax.set_ylim(extent["y"][0], extent["y"][1])
162162

163163
for group in groups:
164+
165+
# Getting cell ids belonging to group and casting them to int for later numpy comparisons
164166
vaid_cell_ids = table[table[params["color_key"]] == group][params["instance_key"]].values
167+
vaid_cell_ids = [int(id) for id in vaid_cell_ids]
165168

166169
# define all out-of-group cells as background
167170
in_group_mask = segmentation.copy()

0 commit comments

Comments
 (0)