Reference
Plotly-based plotting utilities and an extensible plotting interface for pyspect.
This module provides
- update_theme: Configures a Plotly figure with light/dark 2D/3D themes, common font, axis styles, and background colors. The function mutates the provided figure (template, paper/plot background) and supplies a template_layout for consistent styling.
- with_figure: A decorator that ensures a Plotly figure is available, normalizes
keyword arguments, and applies theming and layout updates. Depending on dim:
- dim=None: no axis remapping
- dim=2: xaxis_ and yaxis_ kwds map to layout_xaxis_ / layout_yaxis_
- dim=3: xaxis_, yaxis_, zaxis_ map to layout_scene_; camera_ maps to layout_scene_camera_ It also collects theme_ and layout_ prefixes to feed update_theme and figure layout.
The core interface is PlotlyImpl[R] (subclass of AxesImpl), which offers:
- A general dispatcher PlotlyImpl.plot(..., method="name") that calls either
self.plot_name or self.name on provided inputs.
- Three transformation hooks to implement in subclasses:
- transform_to_bitmap(inp, axes)
: 2D boolean array
- transform_to_surface(inp, axes)
: 2D float array
- transform_to_isosurface(inp, axes)
: 3D float array
Ready-to-use plotting methods (decorated with with_figure) build meshes over axis
bounds (min/max) via numpy.meshgrid (indexing='ij'), label axes using axis_name/unit,
and add corresponding Plotly traces:
- plot_bitmap
: 2D Heatmap of a boolean mask (True mapped to value
in [zmin, zmax]).
- plot_contour
: 2D contour lines from a 2D scalar field.
- plot_surface
: 3D surface (z from 2D scalar field) with scene aspectmode='cube'.
- plot_isosurface
: 3D isosurface at a given level from a 3D volume; caps hidden and
surface_count=1 by default; scene aspectmode='cube'.
The nested PlotlyImpl.PLOT namespace provides spherical-to-cartesian helpers and a set of precomputed camera.eye presets (e.g., EYE_HI_NE, EYE_MH_S, EYE_ZENITH, etc.) to quickly position 3D views.
PlotlyImpl is designed for easy extension, adding plotting functionality to implementations that inherit from it.
PlotlyImpl
Bases: AxesImpl
Example plotting interface.
Integrate these methods where pyspect emits data/sets/meshes
and call with either an existing fig=
or let it create one.
Source code in src/pyspect/impls/plotly.py
class PlotlyImpl[R](AxesImpl):
"""
Example plotting interface.
Integrate these methods where pyspect emits data/sets/meshes
and call with either an existing `fig=` or let it create one.
"""
class PLOT:
@staticmethod
def sph_to_cart(r, theta, phi):
"""Spherical (deg) → cartesian dict compatible with Plotly camera.eye."""
th = np.deg2rad(theta)
ph = np.deg2rad(phi)
s = np.sin(th)
return dict(
x=r * s * np.cos(ph),
y=r * s * np.sin(ph),
z=r * np.cos(th),
)
# Layer 1: Higher elevation (closer to the zenith)
EYE_HI_W = sph_to_cart(2.2, 20, -180) # West, high up
EYE_HI_SW = sph_to_cart(2.5, 30, -135) # Southwest, high up
EYE_HI_S = sph_to_cart(2.5, 20, -90) # South, high up
EYE_HI_SE = sph_to_cart(2.5, 30, -45) # Southeast, high up
EYE_HI_E = sph_to_cart(2.2, 20, 0) # East, high up
EYE_HI_NE = sph_to_cart(2.5, 30, 45) # Northeast, high up
EYE_HI_N = sph_to_cart(2.5, 20, 90) # North, high up
EYE_HI_NW = sph_to_cart(2.5, 30, 135) # Northwest, high up
# Layer 2: Medium-high elevation (closer to the horizon, around 45°)
EYE_MH_W = sph_to_cart(2.2, 45, -180) # West, medium height
EYE_MH_SW = sph_to_cart(2.5, 45, -135) # Southwest, medium height
EYE_MH_S = sph_to_cart(2.5, 45, -90) # South, medium height
EYE_MH_SE = sph_to_cart(2.5, 45, -45) # Southeast, medium height
EYE_MH_E = sph_to_cart(2.2, 45, 0) # East, medium height
EYE_MH_NE = sph_to_cart(2.5, 45, 45) # Northeast, medium height
EYE_MH_N = sph_to_cart(2.5, 45, 90) # North, medium height
EYE_MH_NW = sph_to_cart(2.5, 45, 135) # Northwest, medium height
# Layer 3: Medium-low elevation (closer to the nadir)
EYE_ML_W = sph_to_cart(2.2, 70, -180) # West, low elevation
EYE_ML_SW = sph_to_cart(2.5, 60, -135) # Southwest, low elevation
EYE_ML_S = sph_to_cart(2.5, 70, -90) # South, low elevation
EYE_ML_SE = sph_to_cart(2.5, 60, -45) # Southeast, low elevation
EYE_ML_E = sph_to_cart(2.2, 70, 0) # East, low elevation
EYE_ML_NE = sph_to_cart(2.5, 60, 45) # Northeast, low elevation
EYE_ML_N = sph_to_cart(2.5, 70, 90) # North, low elevation
EYE_ML_NW = sph_to_cart(2.5, 60, 135) # Northwest, low elevation
# Layer 4: Low elevation (closer to the nadir)
EYE_LO_W = sph_to_cart(2.2, 90, -180) # West, low elevation
EYE_LO_SW = sph_to_cart(2.5, 80, -135) # Southwest, low elevation
EYE_LO_S = sph_to_cart(2.5, 90, -90) # South, low elevation
EYE_LO_SE = sph_to_cart(2.5, 80, -45) # Southeast, low elevation
EYE_LO_E = sph_to_cart(2.2, 90, 0) # East, low elevation
EYE_LO_NE = sph_to_cart(2.5, 80, 45) # Northeast, low elevation
EYE_LO_N = sph_to_cart(2.5, 90, 90) # North, low elevation
EYE_LO_NW = sph_to_cart(2.5, 80, 135) # Northwest, low elevation
# Example of viewing from directly above and below
EYE_ZENITH = sph_to_cart(2.5, 0, 0) # Directly above (zenith)
EYE_NADIR = sph_to_cart(2.5, 180, 0) # Directly below (nadir)
@with_figure
def plot(self, *args: R | tuple[R, dict], method: str, fig: BaseFigure, **kwds) -> BaseFigure:
"""General plotting interface.
Parameters:
*args: TODO.
method: Plotting method to use. Implementation must provide `{method}` or `plot_{method}`.
fig: Existing figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments passed to the plotting method.
Returns:
fig: The figure containing the plots.
"""
func = (getattr(self, 'plot_' + method, None) or getattr(self, method, None))
if not callable(func):
raise ValueError(f"Unknown plotting method '{method}'")
normalize = lambda x: x if isinstance(x, tuple) else (x, {})
for arg, kw in map(normalize, args):
setdefaults(kw, kwds)
func(arg, fig=fig, **kw)
return fig
def transform_to_bitmap(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to a bitmap (2D boolean array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 2D boolean numpy array representing the bitmap.
"""
raise NotImplementedError("transform_to_bitmap not implemented")
@with_figure(dim=2)
def plot_bitmap(self,
inp: R, *,
value: float = 0.5,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 2D bitmap.
This method visualizes the input data as a 2D bitmap using a heatmap. To select the color
for the "True" values in the bitmap, use the `value` argument. This must be within the
range defined by `zmin` and `zmax` (arguments to go.Heatmap). `zmin` and `zmax` default to
0 and 1, respectively.
*Note:* Requires the `transform_to_bitmap` method to be implemented.
Parameters:
inp: Input data to plot.
value: Value to represent "True" in the bitmap.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the heatmap.
Returns:
fig: The figure containing the bitmap plot.
"""
setdefaults(kwds,
zmin=0, zmax=1,
colorscale="Greens",
showscale=False)
if not len(axes) == 2:
raise ValueError("plot_bitmap expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
if not kwds['zmin'] <= value <= kwds['zmax']:
raise ValueError("plot_bitmap expects value within [zmin, zmax]")
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_bitmap(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_bitmap must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
Z = np.where(Z, value, np.nan)
xaxis = dict(zeroline=False, showline=False, showticklabels=True)
yaxis = xaxis.copy()
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Heatmap(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
def transform_to_scatter(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to scatter points (N x 2 float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
An (N, 2) float numpy array representing the scatter points.
"""
raise NotImplementedError("transform_to_scatter not implemented")
@with_figure(dim=2)
def plot_fill(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a filled 2D area.
This method visualizes the input data as a filled area using a scatter plot with
`fill='toself'`. The area is defined by the points returned by `transform_to_scatter`.
*Note:* Requires the `transform_to_scatter` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the scatter plot.
Returns:
fig: The figure containing the filled area plot.
"""
setdefaults(kwds,
fill='toself',
mode='lines',
line=dict(width=0),
fillcolor='#74c476', # Roughly 50% of Greens
showlegend=False)
if not len(axes) == 2:
raise ValueError("plot_fill expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
P = self.transform_to_scatter(inp, axes=axes, **transf_kw)
if P.ndim != 2 or P.shape[1] != 2:
raise ValueError("transform_to_scatter must return an (N, 2) array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_fill expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_fill expects exactly 2 max_bounds")
xaxis = dict(zeroline=False, showline=False, showticklabels=True)
yaxis = xaxis.copy()
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Scatter(
x=P[:, 0],
y=P[:, 1],
**kwds
))
def transform_to_surface(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to a 3D surface (2D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 2D float numpy array whose value represents the surface.
"""
raise NotImplementedError("transform_to_surface not implemented")
@with_figure(dim=2)
def plot_contour(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 2D contour.
This method visualizes the input data as a 2D contour using a contour plot. The contour
levels are determined by the values in the 2D array returned by `transform_to_surface`.
*Note:* Requires the `transform_to_surface` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the contour plot.
Returns:
fig: The figure containing the contour plot.
"""
setdefaults(kwds,
showscale=False,
contours_coloring='lines',
line=dict(width=2))
if not len(axes) == 2:
raise ValueError("plot_contour expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_surface(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_surface must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_contour expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_contour expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
xaxis, yaxis = {}, {}
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Contour(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
@with_figure(dim=3)
def plot_surface(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 3D surface.
This method visualizes the input data as a 3D surface using a surface plot. The height of
the surface is determined by the values in the 2D array returned by `transform_to_surface`.
*Note:* Requires the `transform_to_surface` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the surface plot.
Returns:
fig: The figure containing the surface plot.
"""
setdefaults(kwds,
showscale=False)
if not len(axes) == 2:
raise ValueError("plot_surface expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_surface(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_surface must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
xaxis, yaxis = {}, {}
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(scene=dict(aspectmode='cube',
xaxis=xaxis,
yaxis=yaxis,
zaxis_title='Value'))
return fig.add_trace(go.Surface(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
def transform_to_isosurface(self, inp: R, axes: Axes3D, **kwds) -> np.ndarray:
"""Transform input data to a 3D volume (3D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Three axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 3D float numpy array whose value represents the volume.
"""
raise NotImplementedError("transform_to_isosurface not implemented")
@with_figure(dim=3)
def plot_isosurface(self,
inp: R, *,
level: float = 0.0,
axes: Axes3D = (0, 1, 2),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 3D isosurface.
This method visualizes the input data as a 3D isosurface using an isosurface plot. The
isosurface is extracted at the specified `level` from the 3D volume returned by
`transform_to_isosurface`.
*Note:* Requires the `transform_to_isosurface` method to be implemented.
Parameters:
inp: Input data to plot.
level: Level at which to extract the isosurface.
axes: Three axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the isosurface plot.
Returns:
fig: The figure containing the isosurface plot.
"""
setdefaults(kwds,
colorscale='Greens',
showscale=False,
isomin=level,
isomax=level,
surface_count=1,
caps=dict(x_show=False, y_show=False, z_show=False))
if not len(axes) == 3:
raise ValueError("plot_isosurface expects exactly 3 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
V = self.transform_to_isosurface(inp, axes=axes, **transf_kw)
if V.ndim != 3: raise ValueError("transform_to_isosurface must return a 3D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 3: raise ValueError("plot_isosurface expects exactly 3 min_bounds")
if len(max_bounds) != 3: raise ValueError("plot_isosurface expects exactly 3 max_bounds")
X, Y, Z = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], V.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], V.shape[1]),
np.linspace(min_bounds[2], max_bounds[2], V.shape[2]),
indexing='ij')
xaxis, yaxis, zaxis = {}, {}, {}
for i, axis in enumerate((xaxis, yaxis, zaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(scene=dict(aspectmode='cube',
xaxis=xaxis,
yaxis=yaxis,
zaxis=zaxis))
return fig.add_trace(go.Isosurface(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
value=V.flatten(),
**kwds
))
PLOT
Source code in src/pyspect/impls/plotly.py
class PLOT:
@staticmethod
def sph_to_cart(r, theta, phi):
"""Spherical (deg) → cartesian dict compatible with Plotly camera.eye."""
th = np.deg2rad(theta)
ph = np.deg2rad(phi)
s = np.sin(th)
return dict(
x=r * s * np.cos(ph),
y=r * s * np.sin(ph),
z=r * np.cos(th),
)
# Layer 1: Higher elevation (closer to the zenith)
EYE_HI_W = sph_to_cart(2.2, 20, -180) # West, high up
EYE_HI_SW = sph_to_cart(2.5, 30, -135) # Southwest, high up
EYE_HI_S = sph_to_cart(2.5, 20, -90) # South, high up
EYE_HI_SE = sph_to_cart(2.5, 30, -45) # Southeast, high up
EYE_HI_E = sph_to_cart(2.2, 20, 0) # East, high up
EYE_HI_NE = sph_to_cart(2.5, 30, 45) # Northeast, high up
EYE_HI_N = sph_to_cart(2.5, 20, 90) # North, high up
EYE_HI_NW = sph_to_cart(2.5, 30, 135) # Northwest, high up
# Layer 2: Medium-high elevation (closer to the horizon, around 45°)
EYE_MH_W = sph_to_cart(2.2, 45, -180) # West, medium height
EYE_MH_SW = sph_to_cart(2.5, 45, -135) # Southwest, medium height
EYE_MH_S = sph_to_cart(2.5, 45, -90) # South, medium height
EYE_MH_SE = sph_to_cart(2.5, 45, -45) # Southeast, medium height
EYE_MH_E = sph_to_cart(2.2, 45, 0) # East, medium height
EYE_MH_NE = sph_to_cart(2.5, 45, 45) # Northeast, medium height
EYE_MH_N = sph_to_cart(2.5, 45, 90) # North, medium height
EYE_MH_NW = sph_to_cart(2.5, 45, 135) # Northwest, medium height
# Layer 3: Medium-low elevation (closer to the nadir)
EYE_ML_W = sph_to_cart(2.2, 70, -180) # West, low elevation
EYE_ML_SW = sph_to_cart(2.5, 60, -135) # Southwest, low elevation
EYE_ML_S = sph_to_cart(2.5, 70, -90) # South, low elevation
EYE_ML_SE = sph_to_cart(2.5, 60, -45) # Southeast, low elevation
EYE_ML_E = sph_to_cart(2.2, 70, 0) # East, low elevation
EYE_ML_NE = sph_to_cart(2.5, 60, 45) # Northeast, low elevation
EYE_ML_N = sph_to_cart(2.5, 70, 90) # North, low elevation
EYE_ML_NW = sph_to_cart(2.5, 60, 135) # Northwest, low elevation
# Layer 4: Low elevation (closer to the nadir)
EYE_LO_W = sph_to_cart(2.2, 90, -180) # West, low elevation
EYE_LO_SW = sph_to_cart(2.5, 80, -135) # Southwest, low elevation
EYE_LO_S = sph_to_cart(2.5, 90, -90) # South, low elevation
EYE_LO_SE = sph_to_cart(2.5, 80, -45) # Southeast, low elevation
EYE_LO_E = sph_to_cart(2.2, 90, 0) # East, low elevation
EYE_LO_NE = sph_to_cart(2.5, 80, 45) # Northeast, low elevation
EYE_LO_N = sph_to_cart(2.5, 90, 90) # North, low elevation
EYE_LO_NW = sph_to_cart(2.5, 80, 135) # Northwest, low elevation
# Example of viewing from directly above and below
EYE_ZENITH = sph_to_cart(2.5, 0, 0) # Directly above (zenith)
EYE_NADIR = sph_to_cart(2.5, 180, 0) # Directly below (nadir)
sph_to_cart(r, theta, phi)
staticmethod
Spherical (deg) → cartesian dict compatible with Plotly camera.eye.
Source code in src/pyspect/impls/plotly.py
@staticmethod
def sph_to_cart(r, theta, phi):
"""Spherical (deg) → cartesian dict compatible with Plotly camera.eye."""
th = np.deg2rad(theta)
ph = np.deg2rad(phi)
s = np.sin(th)
return dict(
x=r * s * np.cos(ph),
y=r * s * np.sin(ph),
z=r * np.cos(th),
)
plot(*args, method, fig, **kwds)
General plotting interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args
|
R | tuple[R, dict]
|
TODO. |
()
|
method
|
str
|
Plotting method to use. Implementation must provide |
required |
fig
|
BaseFigure
|
Existing figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments passed to the plotting method. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the plots. |
Source code in src/pyspect/impls/plotly.py
@with_figure
def plot(self, *args: R | tuple[R, dict], method: str, fig: BaseFigure, **kwds) -> BaseFigure:
"""General plotting interface.
Parameters:
*args: TODO.
method: Plotting method to use. Implementation must provide `{method}` or `plot_{method}`.
fig: Existing figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments passed to the plotting method.
Returns:
fig: The figure containing the plots.
"""
func = (getattr(self, 'plot_' + method, None) or getattr(self, method, None))
if not callable(func):
raise ValueError(f"Unknown plotting method '{method}'")
normalize = lambda x: x if isinstance(x, tuple) else (x, {})
for arg, kw in map(normalize, args):
setdefaults(kw, kwds)
func(arg, fig=fig, **kw)
return fig
plot_bitmap(inp, *, value=0.5, axes=(0, 1), fig, **kwds)
Plot a 2D bitmap.
This method visualizes the input data as a 2D bitmap using a heatmap. To select the color
for the "True" values in the bitmap, use the value
argument. This must be within the
range defined by zmin
and zmax
(arguments to go.Heatmap). zmin
and zmax
default to
0 and 1, respectively.
Note: Requires the transform_to_bitmap
method to be implemented.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to plot. |
required |
value
|
float
|
Value to represent "True" in the bitmap. |
0.5
|
axes
|
Axes2D
|
Two axes to project onto. |
(0, 1)
|
fig
|
BaseFigure
|
Figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments for the heatmap. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the bitmap plot. |
Source code in src/pyspect/impls/plotly.py
@with_figure(dim=2)
def plot_bitmap(self,
inp: R, *,
value: float = 0.5,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 2D bitmap.
This method visualizes the input data as a 2D bitmap using a heatmap. To select the color
for the "True" values in the bitmap, use the `value` argument. This must be within the
range defined by `zmin` and `zmax` (arguments to go.Heatmap). `zmin` and `zmax` default to
0 and 1, respectively.
*Note:* Requires the `transform_to_bitmap` method to be implemented.
Parameters:
inp: Input data to plot.
value: Value to represent "True" in the bitmap.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the heatmap.
Returns:
fig: The figure containing the bitmap plot.
"""
setdefaults(kwds,
zmin=0, zmax=1,
colorscale="Greens",
showscale=False)
if not len(axes) == 2:
raise ValueError("plot_bitmap expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
if not kwds['zmin'] <= value <= kwds['zmax']:
raise ValueError("plot_bitmap expects value within [zmin, zmax]")
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_bitmap(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_bitmap must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
Z = np.where(Z, value, np.nan)
xaxis = dict(zeroline=False, showline=False, showticklabels=True)
yaxis = xaxis.copy()
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Heatmap(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
plot_contour(inp, *, axes=(0, 1), fig, **kwds)
Plot a 2D contour.
This method visualizes the input data as a 2D contour using a contour plot. The contour
levels are determined by the values in the 2D array returned by transform_to_surface
.
Note: Requires the transform_to_surface
method to be implemented.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to plot. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
(0, 1)
|
fig
|
BaseFigure
|
Figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments for the contour plot. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the contour plot. |
Source code in src/pyspect/impls/plotly.py
@with_figure(dim=2)
def plot_contour(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 2D contour.
This method visualizes the input data as a 2D contour using a contour plot. The contour
levels are determined by the values in the 2D array returned by `transform_to_surface`.
*Note:* Requires the `transform_to_surface` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the contour plot.
Returns:
fig: The figure containing the contour plot.
"""
setdefaults(kwds,
showscale=False,
contours_coloring='lines',
line=dict(width=2))
if not len(axes) == 2:
raise ValueError("plot_contour expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_surface(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_surface must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_contour expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_contour expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
xaxis, yaxis = {}, {}
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Contour(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
plot_fill(inp, *, axes=(0, 1), fig, **kwds)
Plot a filled 2D area.
This method visualizes the input data as a filled area using a scatter plot with
fill='toself'
. The area is defined by the points returned by transform_to_scatter
.
Note: Requires the transform_to_scatter
method to be implemented.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to plot. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
(0, 1)
|
fig
|
BaseFigure
|
Figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments for the scatter plot. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the filled area plot. |
Source code in src/pyspect/impls/plotly.py
@with_figure(dim=2)
def plot_fill(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a filled 2D area.
This method visualizes the input data as a filled area using a scatter plot with
`fill='toself'`. The area is defined by the points returned by `transform_to_scatter`.
*Note:* Requires the `transform_to_scatter` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the scatter plot.
Returns:
fig: The figure containing the filled area plot.
"""
setdefaults(kwds,
fill='toself',
mode='lines',
line=dict(width=0),
fillcolor='#74c476', # Roughly 50% of Greens
showlegend=False)
if not len(axes) == 2:
raise ValueError("plot_fill expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
P = self.transform_to_scatter(inp, axes=axes, **transf_kw)
if P.ndim != 2 or P.shape[1] != 2:
raise ValueError("transform_to_scatter must return an (N, 2) array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_fill expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_fill expects exactly 2 max_bounds")
xaxis = dict(zeroline=False, showline=False, showticklabels=True)
yaxis = xaxis.copy()
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(xaxis=xaxis, yaxis=yaxis)
return fig.add_trace(go.Scatter(
x=P[:, 0],
y=P[:, 1],
**kwds
))
plot_isosurface(inp, *, level=0.0, axes=(0, 1, 2), fig, **kwds)
Plot a 3D isosurface.
This method visualizes the input data as a 3D isosurface using an isosurface plot. The
isosurface is extracted at the specified level
from the 3D volume returned by
transform_to_isosurface
.
Note: Requires the transform_to_isosurface
method to be implemented.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to plot. |
required |
level
|
float
|
Level at which to extract the isosurface. |
0.0
|
axes
|
Axes3D
|
Three axes to project onto. |
(0, 1, 2)
|
fig
|
BaseFigure
|
Figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments for the isosurface plot. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the isosurface plot. |
Source code in src/pyspect/impls/plotly.py
@with_figure(dim=3)
def plot_isosurface(self,
inp: R, *,
level: float = 0.0,
axes: Axes3D = (0, 1, 2),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 3D isosurface.
This method visualizes the input data as a 3D isosurface using an isosurface plot. The
isosurface is extracted at the specified `level` from the 3D volume returned by
`transform_to_isosurface`.
*Note:* Requires the `transform_to_isosurface` method to be implemented.
Parameters:
inp: Input data to plot.
level: Level at which to extract the isosurface.
axes: Three axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the isosurface plot.
Returns:
fig: The figure containing the isosurface plot.
"""
setdefaults(kwds,
colorscale='Greens',
showscale=False,
isomin=level,
isomax=level,
surface_count=1,
caps=dict(x_show=False, y_show=False, z_show=False))
if not len(axes) == 3:
raise ValueError("plot_isosurface expects exactly 3 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
V = self.transform_to_isosurface(inp, axes=axes, **transf_kw)
if V.ndim != 3: raise ValueError("transform_to_isosurface must return a 3D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 3: raise ValueError("plot_isosurface expects exactly 3 min_bounds")
if len(max_bounds) != 3: raise ValueError("plot_isosurface expects exactly 3 max_bounds")
X, Y, Z = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], V.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], V.shape[1]),
np.linspace(min_bounds[2], max_bounds[2], V.shape[2]),
indexing='ij')
xaxis, yaxis, zaxis = {}, {}, {}
for i, axis in enumerate((xaxis, yaxis, zaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(scene=dict(aspectmode='cube',
xaxis=xaxis,
yaxis=yaxis,
zaxis=zaxis))
return fig.add_trace(go.Isosurface(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
value=V.flatten(),
**kwds
))
plot_surface(inp, *, axes=(0, 1), fig, **kwds)
Plot a 3D surface.
This method visualizes the input data as a 3D surface using a surface plot. The height of
the surface is determined by the values in the 2D array returned by transform_to_surface
.
Note: Requires the transform_to_surface
method to be implemented.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to plot. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
(0, 1)
|
fig
|
BaseFigure
|
Figure to plot into. If not provided, a new figure is created. |
required |
**kwds
|
Additional keyword arguments for the surface plot. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
fig |
BaseFigure
|
The figure containing the surface plot. |
Source code in src/pyspect/impls/plotly.py
@with_figure(dim=3)
def plot_surface(self,
inp: R, *,
axes: Axes2D = (0, 1),
fig: BaseFigure,
**kwds) -> BaseFigure:
"""Plot a 3D surface.
This method visualizes the input data as a 3D surface using a surface plot. The height of
the surface is determined by the values in the 2D array returned by `transform_to_surface`.
*Note:* Requires the `transform_to_surface` method to be implemented.
Parameters:
inp: Input data to plot.
axes: Two axes to project onto.
fig: Figure to plot into. If not provided, a new figure is created.
**kwds: Additional keyword arguments for the surface plot.
Returns:
fig: The figure containing the surface plot.
"""
setdefaults(kwds,
showscale=False)
if not len(axes) == 2:
raise ValueError("plot_surface expects exactly 2 axes")
axes = tuple(self.axis(ax) for ax in axes)
transf_kw = collect_prefix(kwds, 'transform_', remove=True)
Z = self.transform_to_surface(inp, axes=axes, **transf_kw)
if Z.ndim != 2: raise ValueError("transform_to_surface must return a 2D array")
min_bounds = kwds.pop('min_bounds', [self._min_bounds[i] for i in axes])
max_bounds = kwds.pop('max_bounds', [self._max_bounds[i] for i in axes])
if len(min_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 min_bounds")
if len(max_bounds) != 2: raise ValueError("plot_bitmap expects exactly 2 max_bounds")
X, Y = np.meshgrid(np.linspace(min_bounds[0], max_bounds[0], Z.shape[0]),
np.linspace(min_bounds[1], max_bounds[1], Z.shape[1]),
indexing='ij')
xaxis, yaxis = {}, {}
for i, axis in enumerate((xaxis, yaxis)):
title = self.axis_name(axes[i])
if unit := self.axis_unit(axes[i]):
title += f' [{unit}]'
axis.update(range=[min_bounds[i], max_bounds[i]], title=title)
fig.update_layout(scene=dict(aspectmode='cube',
xaxis=xaxis,
yaxis=yaxis,
zaxis_title='Value'))
return fig.add_trace(go.Surface(
x=X.flatten(),
y=Y.flatten(),
z=Z.flatten(),
**kwds
))
transform_to_bitmap(inp, axes, **kwds)
Transform input data to a bitmap (2D boolean array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to transform. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
required |
**kwds
|
Additional keyword arguments for the transformation. |
{}
|
Returns:
Type | Description |
---|---|
ndarray
|
A 2D boolean numpy array representing the bitmap. |
Source code in src/pyspect/impls/plotly.py
def transform_to_bitmap(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to a bitmap (2D boolean array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 2D boolean numpy array representing the bitmap.
"""
raise NotImplementedError("transform_to_bitmap not implemented")
transform_to_isosurface(inp, axes, **kwds)
Transform input data to a 3D volume (3D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to transform. |
required |
axes
|
Axes3D
|
Three axes to project onto. |
required |
**kwds
|
Additional keyword arguments for the transformation. |
{}
|
Returns:
Type | Description |
---|---|
ndarray
|
A 3D float numpy array whose value represents the volume. |
Source code in src/pyspect/impls/plotly.py
def transform_to_isosurface(self, inp: R, axes: Axes3D, **kwds) -> np.ndarray:
"""Transform input data to a 3D volume (3D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Three axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 3D float numpy array whose value represents the volume.
"""
raise NotImplementedError("transform_to_isosurface not implemented")
transform_to_scatter(inp, axes, **kwds)
Transform input data to scatter points (N x 2 float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to transform. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
required |
**kwds
|
Additional keyword arguments for the transformation. |
{}
|
Returns:
Type | Description |
---|---|
ndarray
|
An (N, 2) float numpy array representing the scatter points. |
Source code in src/pyspect/impls/plotly.py
def transform_to_scatter(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to scatter points (N x 2 float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
An (N, 2) float numpy array representing the scatter points.
"""
raise NotImplementedError("transform_to_scatter not implemented")
transform_to_surface(inp, axes, **kwds)
Transform input data to a 3D surface (2D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inp
|
R
|
Input data to transform. |
required |
axes
|
Axes2D
|
Two axes to project onto. |
required |
**kwds
|
Additional keyword arguments for the transformation. |
{}
|
Returns:
Type | Description |
---|---|
ndarray
|
A 2D float numpy array whose value represents the surface. |
Source code in src/pyspect/impls/plotly.py
def transform_to_surface(self, inp: R, axes: Axes2D, **kwds) -> np.ndarray:
"""Transform input data to a 3D surface (2D float array).
This is a stub implementation. Actual implementation depends on the data type R.
Parameters:
inp: Input data to transform.
axes: Two axes to project onto.
**kwds: Additional keyword arguments for the transformation.
Returns:
A 2D float numpy array whose value represents the surface.
"""
raise NotImplementedError("transform_to_surface not implemented")
collect_keys(d, *keys, default=...)
Collect a subset of keys from d
.
Behavior:
- If default
is Ellipsis (the default), include only keys that exist in d
.
- Otherwise, include all requested keys
, filling missing ones with default
.
Returns: - New dict containing the selected keys.
Source code in src/pyspect/utils/__init__.py
def collect_keys(d: dict, *keys, default=...):
"""Collect a subset of keys from `d`.
Behavior:
- If `default` is Ellipsis (the default), include only keys that exist in `d`.
- Otherwise, include all requested `keys`, filling missing ones with `default`.
Returns:
- New dict containing the selected keys.
"""
if default is Ellipsis:
return {k: d[k] for k in keys if k in d}
else:
return {k: d.get(k, default) for k in keys}
collect_prefix(d, prefix, remove=False)
Extract and remove items whose keys start with prefix
.
Side effect:
- Matching items are popped from d
.
Parameters:
- d: source dictionary (mutated)
- prefix: string to match at the start of each key
- remove: if True, strip prefix
from keys in the returned dict;
if False, keep original keys
Returns: - New dict of the extracted items.
Source code in src/pyspect/utils/__init__.py
def collect_prefix(d: Dict[str, Any], prefix: str, remove=False) -> Dict[str, Any]:
"""Extract and remove items whose keys start with `prefix`.
Side effect:
- Matching items are popped from `d`.
Parameters:
- d: source dictionary (mutated)
- prefix: string to match at the start of each key
- remove: if True, strip `prefix` from keys in the returned dict;
if False, keep original keys
Returns:
- New dict of the extracted items.
"""
if remove:
return {k.removeprefix(prefix): d.pop(k)
for k in list(d) if k.startswith(prefix)}
else:
return {k: d.pop(k)
for k in list(d) if k.startswith(prefix)}
flatten(nested, *, sep='_', inplace=False)
Flatten nested mappings into a single level by joining keys with sep
.
Example: - {"a": {"b": 1}, "c": 2} -> {"a_b": 1, "c": 2} (sep="_")
Parameters:
- nested: mapping to flatten; nested values that are mappings are expanded
- sep: string inserted between joined key parts
- inplace: if True and nested
is mutable, mutate it in place; otherwise return a new dict
Returns:
- A flat dict (or the mutated input when inplace=True
).
Source code in src/pyspect/utils/__init__.py
def flatten(nested: Mapping[str, Any], *, sep: str = "_", inplace: bool = False) -> Dict[str, Any]:
"""Flatten nested mappings into a single level by joining keys with `sep`.
Example:
- {"a": {"b": 1}, "c": 2} -> {"a_b": 1, "c": 2} (sep="_")
Parameters:
- nested: mapping to flatten; nested values that are mappings are expanded
- sep: string inserted between joined key parts
- inplace: if True and `nested` is mutable, mutate it in place; otherwise return a new dict
Returns:
- A flat dict (or the mutated input when `inplace=True`).
"""
if inplace:
for k, v in list(nested.items()):
if isinstance(v, Mapping):
for kk, vv in flatten(v, sep=sep).items():
nested[f"{k}{sep}{kk}"] = vv
del nested[k]
return nested
else:
out: Dict[str, Any] = {}
for k, v in nested.items():
if isinstance(v, Mapping):
for kk, vv in flatten(v, sep=sep).items():
out[f"{k}{sep}{kk}"] = vv
else:
out[k] = v
return out
iterwin(seq, winlen=1)
Yield fixed-size windows from an indexable sequence by striding.
Equivalent to: zip((seq[i::winlen] for i in range(winlen))). For generic iterables, prefer: zip([iter(seq)] * winlen).
Parameters: - seq: indexable sequence (supports slicing) - winlen: positive window/stride length
Yields:
- Tuples of length winlen
with elements at positions i mod winlen
Source code in src/pyspect/utils/__init__.py
def iterwin(seq, winlen=1):
"""Yield fixed-size windows from an indexable sequence by striding.
Equivalent to: zip(*(seq[i::winlen] for i in range(winlen))).
For generic iterables, prefer: zip(*[iter(seq)] * winlen).
Parameters:
- seq: indexable sequence (supports slicing)
- winlen: positive window/stride length
Yields:
- Tuples of length `winlen` with elements at positions i mod winlen
"""
# Works for indexables; for iterables, use: zip(*[iter(seq)]*winlen)
slices = [seq[i::winlen] for i in range(winlen)]
yield from zip(*slices)
prefix_keys(d, prefix)
Return a new dict with prefix
added to every key in d
.
Source code in src/pyspect/utils/__init__.py
def prefix_keys(d: Mapping[str, Any], prefix: str) -> Dict[str, Any]:
"""Return a new dict with `prefix` added to every key in `d`."""
return {f"{prefix}{k}": v for k, v in d.items()}
setdefaults(d, *args, **kwds)
Set default key/value pairs on dict d
without overwriting existing keys.
Calling conventions: - Keyword form: setdefaults(d, a=1, b=2) - Dict form: setdefaults(d, {'a': 1, 'b': 2}) - Variadic kv: setdefaults(d, 'a', 1, 'b', 2) # even number of args
Raises: - TypeError/ValueError if the calling convention is invalid.
Source code in src/pyspect/utils/__init__.py
def setdefaults(d: dict, *args, **kwds) -> None:
"""Set default key/value pairs on dict `d` without overwriting existing keys.
Calling conventions:
- Keyword form: setdefaults(d, a=1, b=2)
- Dict form: setdefaults(d, {'a': 1, 'b': 2})
- Variadic kv: setdefaults(d, 'a', 1, 'b', 2) # even number of args
Raises:
- TypeError/ValueError if the calling convention is invalid.
"""
if not args:
if not kwds:
raise TypeError("setdefaults expected defaults")
defaults = kwds
elif len(args) == 1:
(defaults,) = args
if not isinstance(defaults, dict):
raise TypeError("single-arg form must be a dict")
if kwds:
raise TypeError("cannot mix dict arg with keyword args")
else:
if kwds:
raise TypeError("cannot mix variadic kv with keyword args")
if len(args) % 2 != 0:
raise ValueError("variadic kv form needs even number of args")
defaults = {k: v for k, v in iterwin(args, 2)}
for k, v in defaults.items():
d.setdefault(k, v)
update_theme(name=None, *, aspectratio='4:3', fig)
Apply layout for 2D/3D light/dark themes.
Source code in src/pyspect/impls/plotly.py
def update_theme(name: Optional[str] = None, *,
aspectratio: str = "4:3",
fig: BaseFigure) -> Dict[str, Any]:
"""Apply layout for 2D/3D light/dark themes."""
layout = dict(margin=dict(l=60, r=20, t=40, b=60))
if name is not None:
# Font
font = dict(family="Roboto, Arial, sans-serif", size=14)
layout.update(font=font)
# Dimensions
axes = dict(linewidth=2)
if name[-2:] not in ("2D", "3D"):
name += "2D"
if name.endswith("2D"):
layout.update(xaxis=axes, yaxis=axes)
if name.endswith("3D"):
layout.update(scene=dict(xaxis=axes, yaxis=axes, zaxis=axes))
# Color Theme
if name.startswith("Light"):
fig.update_layout(template="plotly_white")
layout.update(paper_bgcolor="rgba(255, 255, 255, 1)",
plot_bgcolor="rgba(250, 250, 250, 1)")
font.update(color="black")
axes.update(linecolor="rgba(0, 0, 0, 0.3)",
gridcolor="rgba(0, 0, 0, 0.1)",
zerolinecolor="rgba(0, 0, 0, 0.3)")
if name.startswith("Dark"):
fig.update_layout(template="plotly_dark")
layout.update(paper_bgcolor="rgba(26, 28, 36, 1)",
plot_bgcolor="rgba(26, 28, 36, 1)")
font.update(color="white")
axes.update(linecolor="rgba(255, 255, 255, 0.3)",
gridcolor="rgba(255, 255, 255, 0.1)",
zerolinecolor="rgba(255, 255, 255, 0.3)")
fig.update_layout(template_layout=layout)
with_figure(f=None, *, dim=None)
Decorator to handle figure creation and theming for plotting methods. Parameters: dim: Dimensionality of the plot (2 or 3).
Source code in src/pyspect/impls/plotly.py
def with_figure(f: Optional[Callable] = None, *, dim: Optional[int] = None):
"""Decorator to handle figure creation and theming for plotting methods.
Parameters:
dim: Dimensionality of the plot (2 or 3).
"""
if f is not None and not callable(f):
raise ValueError("with_figure decorator only accepts keyword arguments")
if dim not in (None, 2, 3):
raise ValueError("with_figure decorator expects dim=2 or dim=3")
def decorator(f: Callable[..., BaseFigure]):
@wraps(f)
def wrapper(self, *args, **kwds) -> BaseFigure:
flatten(kwds, inplace=True)
# Create new figure if not provided
if "fig" not in kwds:
kwds["fig"] = go.Figure()
# Merge common options into layout
match dim:
case None:
pass
case 2:
for k, v in collect_prefix(kwds, "xaxis_", remove=True).items():
kwds.setdefault(f"layout_xaxis_{k}", v)
for k, v in collect_prefix(kwds, "yaxis_", remove=True).items():
kwds.setdefault(f"layout_yaxis_{k}", v)
case 3:
for k, v in collect_prefix(kwds, "xaxis_", remove=True).items():
kwds.setdefault(f"layout_scene_xaxis_{k}", v)
for k, v in collect_prefix(kwds, "yaxis_", remove=True).items():
kwds.setdefault(f"layout_scene_yaxis_{k}", v)
for k, v in collect_prefix(kwds, "zaxis_", remove=True).items():
kwds.setdefault(f"layout_scene_zaxis_{k}", v)
for k, v in collect_prefix(kwds, "camera_", remove=True).items():
kwds.setdefault(f"layout_scene_camera_{k}", v)
# Collect theme and layout options
theme_args = collect_keys(kwds, "theme").values()
theme_kwds = collect_prefix(kwds, "theme_", remove=True)
layout = collect_prefix(kwds, "layout_", remove=True)
# Call the decorated function
fig = f(self, *args, **kwds)
# Apply theme and layout
update_theme(*theme_args, **theme_kwds, fig=fig)
fig.update_layout(**layout)
return fig
return wrapper
return decorator if f is None else decorator(f)