@@ -711,8 +711,10 @@ def _set_color_source_vec(
711
711
groups : list [str ] | str | None = None ,
712
712
palette : list [str ] | str | None = None ,
713
713
cmap_params : CmapParams | None = None ,
714
+ alpha : float = 1.0 ,
714
715
table_name : str | None = None ,
715
716
table_layer : str | None = None ,
717
+ render_type : Literal ["points" ] | None = None ,
716
718
) -> tuple [ArrayLike | pd .Series | None , ArrayLike , bool ]:
717
719
if value_to_plot is None and element is not None :
718
720
color = np .full (len (element ), na_color )
@@ -757,9 +759,12 @@ def _set_color_source_vec(
757
759
adata = sdata .table ,
758
760
cluster_key = value_to_plot ,
759
761
color_source_vector = color_source_vector ,
762
+ cmap_params = cmap_params ,
763
+ alpha = alpha ,
760
764
groups = groups ,
761
765
palette = palette ,
762
766
na_color = na_color ,
767
+ render_type = render_type ,
763
768
)
764
769
765
770
color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
@@ -912,15 +917,28 @@ def _get_categorical_color_mapping(
912
917
na_color : ColorLike ,
913
918
cluster_key : str | None = None ,
914
919
color_source_vector : ArrayLike | pd .Series [CategoricalDtype ] | None = None ,
920
+ cmap_params : CmapParams | None = None ,
921
+ alpha : float = 1 ,
915
922
groups : list [str ] | str | None = None ,
916
923
palette : list [str ] | str | None = None ,
924
+ render_type : Literal ["points" ] | None = None ,
917
925
) -> Mapping [str , str ]:
918
926
if not isinstance (color_source_vector , Categorical ):
919
927
raise TypeError (f"Expected `categories` to be a `Categorical`, but got { type (color_source_vector ).__name__ } " )
920
928
921
929
if isinstance (groups , str ):
922
930
groups = [groups ]
923
931
932
+ if not palette and render_type == "points" and cmap_params is not None and not cmap_params .cmap_is_default :
933
+ palette = cmap_params .cmap
934
+
935
+ color_idx = color_idx = np .linspace (0 , 1 , len (color_source_vector .categories ))
936
+ if isinstance (palette , ListedColormap ):
937
+ palette = [to_hex (x ) for x in palette (color_idx , alpha = alpha )]
938
+ elif isinstance (palette , LinearSegmentedColormap ):
939
+ palette = [to_hex (palette (x , alpha = alpha )) for x in color_idx ] # type: ignore[attr-defined]
940
+ return dict (zip (color_source_vector .categories , palette , strict = True ))
941
+
924
942
if isinstance (palette , str ):
925
943
palette = [palette ]
926
944
@@ -2011,7 +2029,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool:
2011
2029
2012
2030
2013
2031
def _ax_show_and_transform (
2014
- array : MaskedArray [tuple [int , ...], Any ],
2032
+ array : MaskedArray [tuple [int , ...], Any ] | npt . NDArray [ Any ] ,
2015
2033
trans_data : CompositeGenericTransform ,
2016
2034
ax : Axes ,
2017
2035
alpha : float | None = None ,
0 commit comments