from __future__ import annotations
import typing
import warnings
import numpy as np
import pandas as pd
from ..coords import coord_flip
from ..exceptions import PlotnineWarning
from ..scales.scale_continuous import scale_continuous as ScaleContinuous
from ..utils import log
from .annotate import annotate
from .geom_path import geom_path
from .geom_rug import geom_rug
if typing.TYPE_CHECKING:
from typing import Any, Literal, Optional, Sequence
from typing_extensions import TypeGuard
from plotnine.iapi import panel_view
from plotnine.typing import (
AnyArray,
Axes,
Coord,
Geom,
Layout,
Scale,
Trans,
TupleFloat2,
TupleFloat3,
)
class _geom_logticks(geom_rug):
    """
    Internal geom implementing drawing of annotation_logticks

    The tick positions are computed at draw time from the ranges
    of each panel; any data handed to the layer is ignored.
    """

    DEFAULT_AES = {}
    DEFAULT_PARAMS = {
        "stat": "identity",
        "position": "identity",
        "na_rm": False,
        "sides": "bl",
        "alpha": 1,
        "color": "black",
        "size": 0.5,
        "linetype": "solid",
        "lengths": (0.036, 0.0225, 0.012),
        "base": 10,
    }
    draw_legend = staticmethod(geom_path.draw_legend)

    def draw_layer(
        self, data: pd.DataFrame, layout: Layout, coord: Coord, **params: Any
    ):
        """
        Draw ticks on every panel

        Parameters
        ----------
        data : dataframe
            Layer data. It is passed on untouched to ``draw_panel``.
        layout : Layout
            Provides the panel ids (``layout.layout["PANEL"]``), the
            per-panel ``panel_params`` and the per-panel axes.
        coord : coord
            Coordinate system of the plot.
        **params : dict
            Parameters of the geom, forwarded to ``draw_panel``.
        """
        for pid in layout.layout["PANEL"]:
            # PANEL ids are used as 1-based positions into the
            # panel_params / axs sequences
            ploc = pid - 1
            panel_params = layout.panel_params[ploc]
            ax = layout.axs[ploc]
            self.draw_panel(data, panel_params, coord, ax, **params)

    @staticmethod
    def _check_log_scale(
        base: Optional[float],
        sides: str,
        panel_params: panel_view,
        coord: Coord,
    ) -> TupleFloat2:
        """
        Check the log transforms

        Parameters
        ----------
        base : float or None
            Base of the logarithm in which the ticks will be
            calculated. If ``None``, the base of the log transform
            of the scale will be used.
        sides : str (default: bl)
            Sides onto which to draw the marks. Any combination
            chosen from the characters ``btlr``, for *bottom*, *top*,
            *left* or *right* side marks. Only the axes named by
            ``sides`` are checked.
        panel_params : panel_view
            ``x`` and ``y`` view scale values.
        coord : coord
            Coordinate (e.g. coord_cartesian) system of the geom.
            When it is a ``coord_flip`` the ``x`` and ``y`` scales
            read from ``panel_params`` are exchanged before they
            are checked.

        Returns
        -------
        out : tuple
            The bases (base_x, base_y) to use when generating the ticks.
            An axis that is not requested by ``sides`` gets base 10.

        Warns
        -----
        PlotnineWarning
            If a checked scale is not a continuous scale with a log
            transform, or if the base of its transform differs from
            a user supplied ``base``.
        """

        def is_log_trans(t: Trans) -> bool:
            # A transform counts as logarithmic when it exposes a
            # ``base`` and its class name starts with "log"
            return hasattr(t, "base") and t.__class__.__name__.startswith(
                "log"
            )

        def get_base(sc, ubase: Optional[float]) -> float:
            ae = sc.aesthetics[0]

            if not isinstance(sc, ScaleContinuous) or not is_log_trans(
                sc.trans
            ):
                warnings.warn(
                    f"annotation_logticks for {ae}-axis which does not have "
                    "a log scale. The logticks may not make sense.",
                    PlotnineWarning,
                )
                return 10 if ubase is None else ubase

            base = sc.trans.base  # pyright: ignore
            if ubase is not None and base != ubase:
                # Name the axis that was actually checked; the user
                # supplied base still wins
                warnings.warn(
                    f"The {ae}-axis is log transformed in base={base}, "
                    "but the annotation_logticks are computed in base="
                    f"{ubase}",
                    PlotnineWarning,
                )
                return ubase
            return base

        # Defaults for the axes that are not requested by ``sides``
        base_x, base_y = 10, 10
        x_scale = panel_params.x.scale
        y_scale = panel_params.y.scale

        if isinstance(coord, coord_flip):
            x_scale, y_scale = y_scale, x_scale

        if "t" in sides or "b" in sides:
            base_x = get_base(x_scale, base)

        if "l" in sides or "r" in sides:
            base_y = get_base(y_scale, base)

        return base_x, base_y

    @staticmethod
    def _calc_ticks(
        value_range: TupleFloat2, base: float
    ) -> tuple[AnyArray, AnyArray, AnyArray]:
        """
        Calculate tick marks within a range

        Parameters
        ----------
        value_range: tuple
            Range for which to calculate ticks.
        base : number
            Base of logarithm

        Returns
        -------
        out: tuple
            (major, middle, minor) tick locations. ``middle`` is
            empty when an interval has no tick that sits strictly
            between two major ticks at its centre.
        """

        def _minor(x: Sequence[Any], mid_idx: int) -> AnyArray:
            # Everything except the two ends and the middle tick
            return np.hstack([x[1:mid_idx], x[mid_idx + 1 : -1]])

        # * Calculate the low and high powers,
        # * Generate the intervals along the low-high power range.
        #   The intervals are in normal space
        # * Calculate evenly spaced breaks in normal space, then convert
        #   them to log space.
        low = np.floor(value_range[0])
        high = np.ceil(value_range[1])
        exponents = np.arange(low, float(high + 1))
        arr = base**exponents

        # Fewer than two powers means there is no interval to divide;
        # the integer powers themselves are the only (major) ticks
        if len(arr) < 2:
            return exponents, np.array([]), np.array([])

        n_ticks = int(np.round(base) - 1)
        breaks = [
            log(np.linspace(b1, b2, n_ticks + 1), base)
            for (b1, b2) in zip(arr, arr[1:])
        ]

        # Partition the breaks in the 3 groups
        major = np.array([x[0] for x in breaks] + [breaks[-1][-1]])
        mid_idx = n_ticks // 2
        # mid_idx == 0 (e.g. base 2) would select the major tick
        # itself as the "middle" tick, so it is excluded
        if n_ticks % 2 and mid_idx > 0:
            middle = np.array([x[mid_idx] for x in breaks])
            minor = np.hstack([_minor(x, mid_idx) for x in breaks])
        else:
            middle = np.array([])
            minor = np.hstack([x[1:-1] for x in breaks])

        return major, middle, minor

    def draw_panel(
        self,
        data: pd.DataFrame,
        panel_params: panel_view,
        coord: Coord,
        ax: Axes,
        **params: Any,
    ):
        """
        Draw the log ticks of a single panel

        Parameters
        ----------
        data : dataframe
            Ignored, the data that is drawn is created here.
        panel_params : panel_view
            ``x`` and ``y`` view scale values of the panel.
        coord : coord
            Coordinate system of the plot.
        ax : axes
            Axes on which to draw.
        **params : dict
            Parameters of the geom. ``sides``, ``lengths``, ``size``,
            ``color``, ``alpha``, ``linetype`` and ``base`` are read.
        """
        # Any passed data is ignored, the relevant data is created
        sides = params["sides"]
        lengths = params["lengths"]
        _aesthetics = {
            "size": params["size"],
            "color": params["color"],
            "alpha": params["alpha"],
            "linetype": params["linetype"],
        }

        def _draw(
            geom: Geom,
            axis: Literal["x", "y"],
            tick_positions: tuple[AnyArray, AnyArray, AnyArray],
        ):
            # One draw call per tick group (major, middle, minor),
            # each with its own length
            for position, length in zip(tick_positions, lengths):
                tick_data = pd.DataFrame({axis: position, **_aesthetics})
                geom.draw_group(
                    tick_data,
                    panel_params,
                    coord,
                    ax,
                    length=length,
                    **params,
                )

        if isinstance(coord, coord_flip):
            tick_range_x = panel_params.y.range
            tick_range_y = panel_params.x.range
        else:
            tick_range_x = panel_params.x.range
            tick_range_y = panel_params.y.range

        # these are already flipped iff coord_flip
        base_x, base_y = self._check_log_scale(
            params["base"], sides, panel_params, coord
        )

        if "b" in sides or "t" in sides:
            tick_positions = self._calc_ticks(tick_range_x, base_x)
            _draw(self, "x", tick_positions)

        if "l" in sides or "r" in sides:
            tick_positions = self._calc_ticks(tick_range_y, base_y)
            _draw(self, "y", tick_positions)
class annotation_logticks(annotate):
    """
    Marginal log ticks.

    If added to a plot that does not have a log10 axis
    on the respective side, a warning will be issued.

    Parameters
    ----------
    sides : str (default: bl)
        Sides onto which to draw the marks. Any combination
        chosen from the characters ``btlr``, for *bottom*, *top*,
        *left* or *right* side marks. If ``coord_flip()`` is used,
        these are the sides *after* the flip.
    alpha : float (default: 1)
        Transparency of the ticks
    color : str | tuple (default: 'black')
        Colour of the ticks
    size : float
        Thickness of the ticks
    linetype : 'solid' | 'dashed' | 'dashdot' | 'dotted' | tuple
        Type of line. Default is *solid*.
    lengths: tuple (default (0.036, 0.0225, 0.012))
        length of the ticks drawn for full / half / tenth
        ticks relative to panel size
    base : float (default: None)
        Base of the logarithm in which the ticks will be
        calculated. If ``None``, the base used to log transform
        the scale will be used.

    Raises
    ------
    ValueError
        If ``lengths`` does not have exactly 3 items.
    """

    def __init__(
        self,
        sides: str = "bl",
        alpha: float = 1,
        color: str | tuple[float, ...] = "black",
        size: float = 0.5,
        linetype: str | tuple[float, ...] = "solid",
        lengths: TupleFloat3 = (0.036, 0.0225, 0.012),
        base: float | None = None,
    ):
        # One length is needed for each group of ticks
        if len(lengths) != 3:
            raise ValueError(
                "length for annotation_logticks must be a tuple of 3 floats"
            )

        # All options are handed to the internal geom that does
        # the drawing
        self._annotation_geom = _geom_logticks(
            sides=sides,
            alpha=alpha,
            color=color,
            size=size,
            linetype=linetype,
            lengths=lengths,
            base=base,
        )
def is_continuous_scale(sc: Scale) -> TypeGuard[ScaleContinuous]:
    """
    Return True if the scale is an instance of the continuous scale class

    Parameters
    ----------
    sc : Scale
        Scale to check.

    Returns
    -------
    out : bool
        Result of the ``isinstance`` check against ``ScaleContinuous``.
        It is declared as a type guard for static type checkers.
    """
    continuous = isinstance(sc, ScaleContinuous)
    return continuous