@@ -15,6 +15,7 @@ use pyo3::{
1515 prelude:: * ,
1616 types:: { PyDict , PyList , PyTuple } ,
1717} ;
18+ use ndarray:: Array2 ;
1819use std:: marker:: PhantomData ;
1920
2021#[ cfg( feature = "curve-sampling" ) ]
@@ -208,19 +209,13 @@ impl Axes {
208209 }
209210 }
210211
211- #[ must_use]
212- pub fn scatter < D > ( & mut self , x : D , y : D ) -> & mut Self
212+ /// Scatter plot of `y` vs. `x` with optional varying marker size
213+ /// and/or color.
214+ pub fn scatter < ' a , D > ( & ' a mut self , x : D , y : D ) -> Scatter < ' a , D >
213215 where
214216 D : AsRef < [ f64 ] > ,
215217 {
216- // FIXME: Do we want to check that `x` and `y` have the same
217- // dimension? Better error message?
218- meth ! ( self . ax, scatter, py -> {
219- let xn = x. as_ref( ) . to_pyarray( py) ;
220- let yn = y. as_ref( ) . to_pyarray( py) ;
221- ( xn, yn) } )
222- . unwrap ( ) ;
223- self
218+ Scatter :: new ( self , x, y)
224219 }
225220
226221 /// Set the title to `txt` for the Axes.
@@ -666,6 +661,172 @@ where
666661 }
667662}
668663
664+ #[ must_use]
665+ pub struct Scatter < ' a , D > {
666+ axes : & ' a Axes ,
667+ x : D ,
668+ y : D ,
669+ // Optional arguments are different from other plot types.
670+ s : Option < & ' a [ f64 ] > ,
671+ c : Option < ScatterColorMat < ' a > > , // Slice of RGBA
672+ marker : Option < & ' a str > , // TODO: generalize
673+ cmap : Option < ( ) > , // TODO
674+ norm : Option < ( ) > , // TODO
675+ // FIXME: It is an error to use vmin/vmax when a norm instance is
676+ // given (but using a str norm name together with vmin/vmax is
677+ // acceptable).
678+ vmin : Option < f64 > ,
679+ vmax : Option < f64 > ,
680+ alpha : Option < f64 > , // ∈ [0, 1]
681+ linewidths : Option < f64 > , // TODO: support array-like
682+ // edgecolors
683+ // colorizer
684+ // plotnonfinite
685+ }
686+
687+ impl < ' a , D > Scatter < ' a , D >
688+ where D : AsRef < [ f64 ] > {
689+ fn new ( axes : & ' a Axes , x : D , y : D ) -> Self {
690+ Self {
691+ axes, x, y,
692+ s : None , c : None , marker : None , cmap : None , norm : None ,
693+ vmin : None , vmax : None , alpha : None , linewidths : None ,
694+ }
695+ }
696+
697+ /// The marker size in points² (typographic points are 1/72 in).
698+ ///
699+ /// Default is rcParams['lines.markersize'] ** 2.
700+ pub fn s ( mut self , s : & ' a [ f64 ] ) -> Self {
701+ self . s = Some ( s) ;
702+ self
703+ }
704+
705+ /// Specify the marker colors.
706+ pub fn c < C > ( mut self , colors : impl ScatterColors ) -> Self
707+ where C : Color ,
708+ {
709+ self . c = Some ( ScatterColorMat :: Colors ( colors. as_mat ( ) ) ) ;
710+ self
711+ }
712+
713+ /// Specify the marker colors as a sequence of `n` numbers to be
714+ /// mapped to colors using `cmap` and `norm` where `n` is the
715+ /// length of the data (see [`Axes::scatter`]).
716+ pub fn cm ( mut self , colors : & ' a [ usize ] ) -> Self {
717+ self . c = Some ( ScatterColorMat :: Cmap ( colors) ) ;
718+ self
719+ }
720+
721+ /// Set the marker style.
722+ pub fn marker ( mut self , m : & ' a str ) -> Self {
723+ self . marker = Some ( m) ;
724+ self
725+ }
726+
727+ /// When using scalar data and no explicit `norm`, `vmin` and
728+ /// [`vmax`] define the data range that the colormap covers.
729+ pub fn vmin ( mut self , v : f64 ) -> Self {
730+ self . vmin = Some ( v) ;
731+ self
732+ }
733+
734+ /// When using scalar data and no explicit `norm`, [`vmin`] and
735+ /// `vmax` define the data range that the colormap covers.
736+ pub fn vmax ( mut self , v : f64 ) -> Self {
737+ self . vmax = Some ( v) ;
738+ self
739+ }
740+
741+ /// Set the alpha blending value, between 0 (transparent) and 1
742+ /// (opaque).
743+ pub fn alpha ( mut self , alpha : f64 ) -> Self {
744+ self . alpha = Some ( alpha. clamp ( 0. , 1. ) ) ;
745+ self
746+ }
747+
748+ /// The linewidth of the marker edges.
749+ ///
750+ /// Note: The default `edgecolors` is "face". You may want to
751+ /// change this as well.
752+ pub fn linewidths ( mut self , lw : f64 ) -> Self {
753+ self . linewidths = Some ( lw) ;
754+ self
755+ }
756+
757+ pub fn plot ( self ) {
758+ // FIXME: Do we want to check that `x` and `y` have the same
759+ // dimension? Better error message?
760+ Python :: attach ( |py| {
761+ match self . c {
762+ Some ( ScatterColorMat :: Cmap ( v) ) => {
763+ self . plot_with_colors ( py, v. to_pyarray ( py) ) ;
764+ }
765+ Some ( ScatterColorMat :: Colors ( ref m) ) => {
766+ self . plot_with_colors ( py, m. to_pyarray ( py) ) ;
767+ }
768+ None => self . plot_with_colors ( py, None :: < & str > ) ,
769+ }
770+ } )
771+ }
772+
773+ fn plot_with_colors < ' py > (
774+ & self ,
775+ py : Python < ' py > ,
776+ c : impl IntoPyObject < ' py > ,
777+ ) {
778+ let xn = self . x . as_ref ( ) . to_pyarray ( py) ;
779+ let yn = self . y . as_ref ( ) . to_pyarray ( py) ;
780+ self . axes . ax . call_method1 ( py, intern ! ( py, "scatter" ) ,
781+ ( xn, yn, self . s , c, self . marker , self . cmap ,
782+ self . norm , self . vmin , self . vmax , self . alpha ,
783+ self . linewidths ) )
784+ . unwrap ( ) ;
785+ }
786+ }
787+
788+ enum ScatterColorMat < ' a > {
789+ Cmap ( & ' a [ usize ] ) ,
790+ Colors ( ndarray:: Array2 < f64 > ) ,
791+ }
792+
793+ /// Possible color specifications for [`Axes::scatter`] plots.
794+ pub trait ScatterColors {
795+ #[ doc( hidden) ]
796+ fn as_mat ( & self ) -> ndarray:: Array2 < f64 > ;
797+ }
798+
799+ impl < C > ScatterColors for & [ C ]
800+ where
801+ C : Color ,
802+ {
803+ fn as_mat ( & self ) -> ndarray:: Array2 < f64 > {
804+ let colors = self . as_ref ( ) ;
805+ let n = colors. len ( ) ;
806+ let mut c: Array2 < f64 > = ndarray:: Array2 :: zeros ( ( n, 4 ) ) ;
807+ for i in 0 .. n {
808+ let ci = colors[ i] . rgba ( ) ;
809+ for j in 0 .. 4 {
810+ c[ ( i, j) ] = ci[ j] ;
811+ }
812+ }
813+ c
814+ }
815+ }
816+
817+ impl < C : Color > ScatterColors for C {
818+ fn as_mat ( & self ) -> ndarray:: Array2 < f64 > {
819+ // A single row array gives the same color for all markers.
820+ let mut c = ndarray:: Array2 :: zeros ( ( 1 , 4 ) ) ;
821+ let color = self . rgba ( ) ;
822+ c[ ( 0 , 0 ) ] = color[ 0 ] ;
823+ c[ ( 0 , 1 ) ] = color[ 1 ] ;
824+ c[ ( 0 , 2 ) ] = color[ 2 ] ;
825+ c[ ( 0 , 3 ) ] = color[ 3 ] ;
826+ c
827+ }
828+ }
829+
669830pub struct QuadContourSet {
670831 contours : Py < PyAny > ,
671832}
0 commit comments