Skip to content

Commit 24193e5

Browse files
committed
Improve the scatter plots
1 parent dc85476 commit 24193e5

3 files changed

Lines changed: 201 additions & 10 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ default = ["curve-sampling"]
2929
doc-comment = "0.3.4"
3030
polars-core = { version = "0.53.0", features = ["fmt"] }
3131
anyhow = "1.0.102"
32+
rand = "0.10.0"
33+
rand_distr = "=0.6.0"

examples/plot_types.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ static BASE: &'static str = "target/plot_types_";
88
fn main() -> anyhow::Result<()> {
99
// Pairwise data
1010
plot_xy()?;
11+
scatter_xy()?;
1112
Ok(())
1213
}
1314

@@ -34,3 +35,30 @@ fn plot_xy() -> anyhow::Result<()> {
3435
fig.save().to_file(format!("{BASE}plot_xy.pdf"))?;
3536
Ok(())
3637
}
38+
39+
fn scatter_xy() -> anyhow::Result<()> {
40+
use rand::RngExt;
41+
use rand_distr::{Normal, Distribution};
42+
mpl::style::using("_mpl-gallery")?;
43+
44+
fn vec<T>(mut f: impl FnMut() -> T) -> Vec<T> {
45+
(0..24).map(|_| f()).collect()
46+
}
47+
48+
let mut rng = rand::rng();
49+
let n = Normal::new(0., 2.)?;
50+
let x = vec(|| 4. + n.sample(&mut rng));
51+
let y = vec(|| 4. + n.sample(&mut rng));
52+
let sizes = vec(|| rng.random_range(15. .. 80.));
53+
let colors = vec(|| rng.random_range(15 .. 80));
54+
55+
let fig = Figure::new()?;
56+
let [[mut ax]] = fig.subplots()?;
57+
58+
ax.scatter(&x, &y).s(&sizes).cm(&colors).vmin(0.).vmax(100.).plot();
59+
60+
ax.set_xlim(0., 8.) .set_xticks((1..8).map(f64::from))
61+
.set_ylim(0., 8.) .set_yticks((1..8).map(f64::from));
62+
fig.save().to_file(format!("{BASE}scatter_xy.pdf"))?;
63+
Ok(())
64+
}

src/axes.rs

Lines changed: 171 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use pyo3::{
1515
prelude::*,
1616
types::{PyDict, PyList, PyTuple},
1717
};
18+
use ndarray::Array2;
1819
use std::marker::PhantomData;
1920

2021
#[cfg(feature = "curve-sampling")]
@@ -208,19 +209,13 @@ impl Axes {
208209
}
209210
}
210211

211-
#[must_use]
212-
pub fn scatter<D>(&mut self, x: D, y: D) -> &mut Self
212+
/// Scatter plot of `y` vs. `x` with optional varying marker size
213+
/// and/or color.
214+
pub fn scatter<'a, D>(&'a mut self, x: D, y: D) -> Scatter<'a, D>
213215
where
214216
D: AsRef<[f64]>,
215217
{
216-
// FIXME: Do we want to check that `x` and `y` have the same
217-
// dimension? Better error message?
218-
meth!(self.ax, scatter, py -> {
219-
let xn = x.as_ref().to_pyarray(py);
220-
let yn = y.as_ref().to_pyarray(py);
221-
(xn, yn) })
222-
.unwrap();
223-
self
218+
Scatter::new(self, x, y)
224219
}
225220

226221
/// Set the title to `txt` for the Axes.
@@ -666,6 +661,172 @@ where
666661
}
667662
}
668663

664+
#[must_use]
665+
pub struct Scatter<'a, D> {
666+
axes: &'a Axes,
667+
x: D,
668+
y: D,
669+
// Optional arguments are different from other plot types.
670+
s: Option<&'a [f64]>,
671+
c: Option<ScatterColorMat<'a>>, // Slice of RGBA
672+
marker: Option<&'a str>, // TODO: generalize
673+
cmap: Option<()>, // TODO
674+
norm: Option<()>, // TODO
675+
// FIXME: It is an error to use vmin/vmax when a norm instance is
676+
// given (but using a str norm name together with vmin/vmax is
677+
// acceptable).
678+
vmin: Option<f64>,
679+
vmax: Option<f64>,
680+
alpha: Option<f64>, // ∈ [0, 1]
681+
linewidths: Option<f64>, // TODO: support array-like
682+
// edgecolors
683+
// colorizer
684+
// plotnonfinite
685+
}
686+
687+
impl<'a, D> Scatter<'a, D>
688+
where D: AsRef<[f64]> {
689+
fn new(axes: &'a Axes, x: D, y: D) -> Self {
690+
Self {
691+
axes, x, y,
692+
s: None, c: None, marker: None, cmap: None, norm: None,
693+
vmin: None, vmax: None, alpha: None, linewidths: None,
694+
}
695+
}
696+
697+
/// The marker size in points² (typographic points are 1/72 in).
698+
///
699+
/// Default is rcParams['lines.markersize'] ** 2.
700+
pub fn s(mut self, s: &'a [f64]) -> Self {
701+
self.s = Some(s);
702+
self
703+
}
704+
705+
/// Specify the marker colors.
706+
pub fn c<C>(mut self, colors: impl ScatterColors) -> Self
707+
where C: Color,
708+
{
709+
self.c = Some(ScatterColorMat::Colors(colors.as_mat()));
710+
self
711+
}
712+
713+
/// Specify the marker colors as a sequence of `n` numbers to be
714+
/// mapped to colors using `cmap` and `norm` where `n` is the
715+
/// length of the data (see [`Axes::scatter`]).
716+
pub fn cm(mut self, colors: &'a [usize]) -> Self {
717+
self.c = Some(ScatterColorMat::Cmap(colors));
718+
self
719+
}
720+
721+
/// Set the marker style.
722+
pub fn marker(mut self, m: &'a str) -> Self {
723+
self.marker = Some(m);
724+
self
725+
}
726+
727+
/// When using scalar data and no explicit `norm`, `vmin` and
728+
/// [`vmax`] define the data range that the colormap covers.
729+
pub fn vmin(mut self, v: f64) -> Self {
730+
self.vmin = Some(v);
731+
self
732+
}
733+
734+
/// When using scalar data and no explicit `norm`, [`vmin`] and
735+
/// `vmax` define the data range that the colormap covers.
736+
pub fn vmax(mut self, v: f64) -> Self {
737+
self.vmax = Some(v);
738+
self
739+
}
740+
741+
/// Set the alpha blending value, between 0 (transparent) and 1
742+
/// (opaque).
743+
pub fn alpha(mut self, alpha: f64) -> Self {
744+
self.alpha = Some(alpha.clamp(0., 1.));
745+
self
746+
}
747+
748+
/// The linewidth of the marker edges.
749+
///
750+
/// Note: The default `edgecolors` is "face". You may want to
751+
/// change this as well.
752+
pub fn linewidths(mut self, lw: f64) -> Self {
753+
self.linewidths = Some(lw);
754+
self
755+
}
756+
757+
pub fn plot(self) {
758+
// FIXME: Do we want to check that `x` and `y` have the same
759+
// dimension? Better error message?
760+
Python::attach(|py| {
761+
match self.c {
762+
Some(ScatterColorMat::Cmap(v)) => {
763+
self.plot_with_colors(py, v.to_pyarray(py));
764+
}
765+
Some(ScatterColorMat::Colors(ref m)) => {
766+
self.plot_with_colors(py, m.to_pyarray(py));
767+
}
768+
None => self.plot_with_colors(py, None::<&str>),
769+
}
770+
})
771+
}
772+
773+
fn plot_with_colors<'py>(
774+
&self,
775+
py: Python<'py>,
776+
c: impl IntoPyObject<'py>,
777+
) {
778+
let xn = self.x.as_ref().to_pyarray(py);
779+
let yn = self.y.as_ref().to_pyarray(py);
780+
self.axes.ax.call_method1(py, intern!(py, "scatter"),
781+
(xn, yn, self.s, c, self.marker, self.cmap,
782+
self.norm, self.vmin, self.vmax, self.alpha,
783+
self.linewidths))
784+
.unwrap();
785+
}
786+
}
787+
788+
enum ScatterColorMat<'a> {
789+
Cmap(&'a [usize]),
790+
Colors(ndarray::Array2<f64>),
791+
}
792+
793+
/// Possible color specifications for [`Axes::scatter`] plots.
794+
pub trait ScatterColors {
795+
#[doc(hidden)]
796+
fn as_mat(&self) -> ndarray::Array2<f64>;
797+
}
798+
799+
impl<C> ScatterColors for &[C]
800+
where
801+
C: Color,
802+
{
803+
fn as_mat(&self) -> ndarray::Array2<f64> {
804+
let colors = self.as_ref();
805+
let n = colors.len();
806+
let mut c: Array2<f64> = ndarray::Array2::zeros((n, 4));
807+
for i in 0 .. n {
808+
let ci = colors[i].rgba();
809+
for j in 0 .. 4 {
810+
c[(i,j)] = ci[j];
811+
}
812+
}
813+
c
814+
}
815+
}
816+
817+
impl<C: Color> ScatterColors for C {
818+
fn as_mat(&self) -> ndarray::Array2<f64> {
819+
// A single row array gives the same color for all markers.
820+
let mut c = ndarray::Array2::zeros((1, 4));
821+
let color = self.rgba();
822+
c[(0,0)] = color[0];
823+
c[(0,1)] = color[1];
824+
c[(0,2)] = color[2];
825+
c[(0,3)] = color[3];
826+
c
827+
}
828+
}
829+
669830
pub struct QuadContourSet {
670831
contours: Py<PyAny>,
671832
}

0 commit comments

Comments
 (0)