# Source code for plotnine.geoms.geom_segment

from __future__ import annotations

import typing

import numpy as np
import pandas as pd

from ..doctools import document
from ..utils import SIZE_FACTOR, interleave, make_line_segments, to_rgba
from .geom import geom
from .geom_path import geom_path

if typing.TYPE_CHECKING:
    from typing import Any

    from plotnine.iapi import panel_view
    from plotnine.typing import (
        Axes,
        Coord,
    )


@document
class geom_segment(geom):
    """
    Line segments

    {usage}

    Parameters
    ----------
    {common_parameters}
    lineend : str (default: butt)
        Line end style, one of *butt*, *round* or *projecting*.
        This option is applied for solid linetypes.
    arrow : plotnine.geoms.geom_path.arrow (default: None)
        Arrow specification. Default is no arrow.

    See Also
    --------
    plotnine.geoms.geom_path.arrow : for adding arrowhead(s) to segments.
    """

    # Aesthetic values used when the user maps/sets none
    DEFAULT_AES = {
        "alpha": 1,
        "color": "black",
        "linetype": "solid",
        "size": 0.5,
    }
    # A segment is defined by its start (x, y) and end (xend, yend) points
    REQUIRED_AES = {"x", "y", "xend", "yend"}
    NON_MISSING_AES = {"linetype", "size", "shape"}
    DEFAULT_PARAMS = {
        "stat": "identity",
        "position": "identity",
        "na_rm": False,
        "lineend": "butt",
        "arrow": None,
    }

    # The legend key is drawn exactly as geom_path draws it
    draw_legend = staticmethod(geom_path.draw_legend)

    @staticmethod
    def draw_group(
        data: pd.DataFrame,
        panel_params: panel_view,
        coord: Coord,
        ax: Axes,
        **params: Any,
    ):
        """
        Draw the segments of one group onto the axes

        Parameters
        ----------
        data : pandas.DataFrame
            Rows with the columns ``x``, ``y``, ``xend``, ``yend``,
            ``color``, ``alpha``, ``size`` and ``linetype``; one row
            per segment.
        panel_params : panel_view
            Passed to ``coord.transform`` and to the arrow drawing
            routine.
        coord : Coord
            Coordinate system used to transform ``data``.
        ax : Axes
            Axes to which the line collection is added.
        **params : Any
            Drawing parameters. ``zorder`` and ``raster`` are required;
            ``lineend`` and ``arrow`` are used when present.
        """
        from matplotlib.collections import LineCollection

        data = coord.transform(data, panel_params)
        # Scaled in place on purpose: the arrow data assembled below
        # reuses this already scaled column.
        data["size"] *= SIZE_FACTOR
        color = to_rgba(data["color"], data["alpha"])

        # start point -> end point, sequence of xy points
        # from which line segments are created
        x = interleave(data["x"], data["xend"])
        y = interleave(data["y"], data["yend"])
        segments = make_line_segments(x, y, ispath=False)
        coll = LineCollection(
            segments,
            edgecolor=color,
            linewidth=data["size"],
            # One linestyle for the whole collection, taken from the
            # first row. Positional access, so that an index without
            # the label 0 does not raise a KeyError.
            linestyle=data["linetype"].iloc[0],
            # Honour the documented lineend parameter; None (when it is
            # absent) leaves matplotlib's default cap style in place.
            capstyle=params.get("lineend"),
            zorder=params["zorder"],
            rasterized=params["raster"],
        )
        ax.add_collection(coll)

        arrow = params.get("arrow")
        if arrow:
            n = len(data)
            # Every segment is its own group made up of two rows;
            # all the start points first, then all the end points.
            idx = np.arange(1, n + 1)
            adata = pd.DataFrame(index=range(n * 2))
            adata["group"] = np.hstack([idx, idx])
            adata["x"] = np.hstack([data["x"], data["xend"]])
            adata["y"] = np.hstack([data["y"], data["yend"]])
            # Both points of a segment share the segment's aesthetics
            for param in ("color", "alpha", "size", "linetype"):
                adata[param] = np.hstack([data[param], data[param]])

            arrow.draw(
                adata, panel_params, coord, ax, constant=False, **params
            )