Source code for plotnine.geoms.annotation_stripes

from __future__ import annotations

import typing
from itertools import cycle, islice

import numpy as np
import pandas as pd

from ..coords import coord_flip
from ..scales.scale_discrete import scale_discrete
from .annotate import annotate
from .geom import geom
from .geom_polygon import geom_polygon
from .geom_rect import geom_rect

if typing.TYPE_CHECKING:
    from typing import Any, Literal, Sequence

    from plotnine.iapi import panel_view
    from plotnine.typing import Axes, Coord, Layout, TupleFloat2


class annotation_stripes(annotate):
    """
    Alternating stripes, centered around each label.

    Useful as a background for geom_jitter.

    Parameters
    ----------
    fill : list-like
        List of colors for the stripes. The default is
        `("#AAAAAA", "#CCCCCC")`
    fill_range : 'cycle' | 'nocycle' | 'auto' | 'no'
        How to fill stripes beyond the range of scale::

            'cycle'      # keep cycling the colors of the
                         # stripes after the range ends
            'nocycle'    # stop cycling the colors of the
                         # stripes after the range ends
            'auto'       # 'cycle' for continuous scales and
                         # 'nocycle' for discrete scales.
            'no'         # Do not add stripes past the range
                         # of the scales

        Default is 'auto'.
    direction : 'vertical' or 'horizontal'
        Orientation of the stripes
    extend : tuple
        Portion of the panel that the stripes span along their
        length, given as fractions of the panel range. The default
        is (0, 1), the whole length. The values should be in the
        range [0, 1].
    **kwargs : dict
        Other aesthetic parameters for the rectangular stripes.
        They include: *alpha*, *color*, *linetype*, and *size*.
    """

    def __init__(
        self,
        fill: Sequence[str] = ("#AAAAAA", "#CCCCCC"),
        fill_range: Literal["auto", "cycle", "no", "nocycle"] = "auto",
        direction: Literal["horizontal", "vertical"] = "vertical",
        extend: TupleFloat2 = (0, 1),
        **kwargs: Any,
    ):
        """
        Validate the direction and create the geom that draws the stripes.

        Raises
        ------
        ValueError
            If `direction` is neither "vertical" nor "horizontal".
        """
        valid_directions = ("vertical", "horizontal")
        if direction in valid_directions:
            # Collect every setting for the stripe geom in one place;
            # the named parameters come first, then the pass-through
            # aesthetics.
            stripe_params = dict(
                fill=fill,
                fill_range=fill_range,
                extend=extend,
                direction=direction,
                **kwargs,
            )
            self._annotation_geom = _geom_stripes(**stripe_params)
        else:
            raise ValueError(f"direction must be one of {valid_directions}")
class _geom_stripes(geom):
    """
    Geom that draws alternating stripes centered on the axis breaks.

    It takes no aesthetics from the data; everything it draws is derived
    from the panel's breaks and ranges together with the parameters in
    ``DEFAULT_PARAMS``.
    """

    DEFAULT_AES = {}
    REQUIRED_AES = set()
    DEFAULT_PARAMS = {
        "stat": "identity",
        "position": "identity",
        "na_rm": False,
        "color": None,
        "fill": ("#AAAAAA", "#CCCCCC"),
        "linetype": "solid",
        "size": 1,
        "alpha": 0.5,
        "direction": "vertical",
        "extend": (0, 1),
        "fill_range": "auto",
    }
    draw_legend = staticmethod(geom_polygon.draw_legend)

    def draw_layer(
        self, data: pd.DataFrame, layout: Layout, coord: Coord, **params: Any
    ):
        """
        Draw stripes on every panel
        """
        for pid in layout.layout["PANEL"]:
            # Panel ids are used as 1-based positions into the
            # per-panel lists of the layout.
            ploc = pid - 1
            panel_params = layout.panel_params[ploc]
            ax = layout.axs[ploc]
            self.draw_group(data, panel_params, coord, ax, **params)

    @staticmethod
    def draw_group(
        data: pd.DataFrame,
        panel_params: panel_view,
        coord: Coord,
        ax: Axes,
        **params: Any,
    ):
        """
        Compute the stripe rectangles for one panel and draw them.

        The incoming `data` is not used; the rectangles are built from
        the breaks and ranges in `panel_params` and then handed to
        ``geom_rect.draw_group``.

        Raises
        ------
        ValueError
            If there is more than one break and the breaks are not
            equally spaced.
        """
        extend = params["extend"]
        fill_range = params["fill_range"]
        direction = params["direction"]

        # Pick the axis the stripes are laid out along; a flipped
        # coordinate system swaps the roles of the two axes.
        if direction == "vertical":
            axis, other_axis = "x", "y"
        else:
            axis, other_axis = "y", "x"

        if isinstance(coord, coord_flip):
            axis, other_axis = other_axis, axis

        _axis = getattr(panel_params, axis)
        breaks = _axis.breaks
        limits = _axis.range
        other_range = getattr(panel_params, other_axis).range

        if fill_range == "auto":
            if isinstance(_axis.scale, scale_discrete):
                fill_range = "nocycle"
            else:
                fill_range = "cycle"

        # Breaks along the width
        n_stripes = len(breaks)
        if n_stripes == 0:
            # No breaks means there is nothing to center a stripe on.
            return None

        many_stripes = n_stripes > 1
        if many_stripes:
            diff = np.diff(breaks)
            step = diff[0]
            # Compare with a relative tolerance so that floating point
            # round-off in the breaks (e.g. 0.1, 0.2, 0.3) is not
            # mistaken for unequal spacing.
            equal_spaces = np.allclose(diff, step, atol=0)
            if not equal_spaces:
                raise ValueError(
                    "The major breaks are not equally spaced. "
                    "We cannot create stripes."
                )
        else:
            # A single break gives no spacing to measure; the value of
            # the break itself is used as the width of the stripe.
            step = breaks[0]

        deltas = np.array([step / 2] * n_stripes)
        xmin = breaks - deltas
        xmax = breaks + deltas

        # Fill the space between the outermost stripes and the ends of
        # the range, but only on the sides where the range really goes
        # beyond the first / last break.
        pad_low = pad_high = False
        if fill_range in ("cycle", "nocycle") and many_stripes:
            pad_low = limits[0] < breaks[0]
            pad_high = limits[1] > breaks[-1]
            if pad_low:
                n_stripes += 1
                xmax = np.insert(xmax, 0, xmin[0])
                xmin = np.insert(xmin, 0, limits[0])
            if pad_high:
                n_stripes += 1
                xmin = np.append(xmin, xmax[-1])
                xmax = np.append(xmax, limits[1])

        # Height
        full_height = other_range[1] - other_range[0]
        ymin = other_range[0] + full_height * extend[0]
        ymax = other_range[0] + full_height * extend[1]

        fill = list(islice(cycle(params["fill"]), n_stripes))
        if fill_range == "nocycle":
            # An added edge stripe continues the color of the stripe
            # next to it. Padding only happens with two or more breaks,
            # so the neighbouring index always exists.
            if pad_low:
                fill[0] = fill[1]
            if pad_high:
                fill[-1] = fill[-2]

        if direction != "vertical":
            xmin, xmax, ymin, ymax = ymin, ymax, xmin, xmax

        data = pd.DataFrame(
            {
                "xmin": xmin,
                "xmax": xmax,
                "ymin": ymin,
                "ymax": ymax,
                "fill": fill,
                "alpha": params["alpha"],
                "color": params["color"],
                "linetype": params["linetype"],
                "size": params["size"],
            }
        )

        return geom_rect.draw_group(data, panel_params, coord, ax, **params)