@@ -618,7 +618,7 @@ def _set_color_source_vec(
618
618
619
619
# numerical case, return early
620
620
if color_source_vector is not None and not isinstance (color_source_vector .dtype , pd .CategoricalDtype ):
621
- if palette [element_index ][ 0 ] is not None :
621
+ if palette [0 ] is not None :
622
622
logger .warning (
623
623
"Ignoring categorical palette which is given for a continuous variable. "
624
624
"Consider using `cmap` to pass a ColorMap."
@@ -632,7 +632,16 @@ def _set_color_source_vec(
632
632
color_source_vector = color_source_vector .remove_categories (categories .difference (groups ))
633
633
categories = groups
634
634
635
- color_map = dict (zip (categories , _get_colors_for_categorical_obs (categories , palette , cmap_params = cmap_params )))
635
+ if groups is not None :
636
+ palette_input = palette [0 ] if palette [0 ] is None else palette
637
+ elif palette is not None :
638
+ palette_input = palette [0 ]
639
+ else :
640
+ palette_input = palette
641
+
642
+ color_map = dict (
643
+ zip (categories , _get_colors_for_categorical_obs (categories , palette_input , cmap_params = cmap_params ))
644
+ )
636
645
637
646
if color_map is None :
638
647
raise ValueError ("Unable to create color palette." )
@@ -1797,6 +1806,8 @@ def _match_length_elements_groups_palette(
1797
1806
params .groups = [groups [0 ] for _ in range (len (render_elements ))]
1798
1807
if palette is not None :
1799
1808
params .palette = [palette [0 ] for _ in range (len (render_elements ))]
1809
+ else :
1810
+ params .palette = [[None ] for _ in range (len (render_elements ))]
1800
1811
else :
1801
1812
if len (groups ) != len (render_elements ):
1802
1813
raise ValueError (
@@ -1835,7 +1846,8 @@ def _update_params(sdata, params, wanted_elements_on_cs, element_type: Literal["
1835
1846
else :
1836
1847
params = _validate_colors_element_table_mapping_points_shapes (sdata , params , wanted_elements_on_cs )
1837
1848
1838
- if params .palette is None :
1839
- params .palette = [[None ] for _ in wanted_elements_on_cs ]
1849
+ # if params.palette is None:
1850
+ # params.palette = [[None] for _ in wanted_elements_on_cs]
1840
1851
image_flag = element_type == "images"
1841
- return _match_length_elements_groups_palette (params , wanted_elements_on_cs , image = image_flag )
1852
+ params = _match_length_elements_groups_palette (params , wanted_elements_on_cs , image = image_flag )
1853
+ return params
0 commit comments