Source code for plotnine.geoms.geom_abline
from __future__ import annotations
import typing
from typing import Sized
from warnings import warn
import numpy as np
import pandas as pd
from ..doctools import document
from ..exceptions import PlotnineWarning
from ..mapping import aes
from ..utils import order_as_data_mapping
from .geom import geom
from .geom_path import geom_path
from .geom_segment import geom_segment
if typing.TYPE_CHECKING:
from typing import Any
from plotnine.iapi import panel_view
from plotnine.typing import Aes, Axes, Coord, DataLike
@document
class geom_abline(geom):
    """
    Lines specified by slope and intercept

    {usage}

    Parameters
    ----------
    {common_parameters}
    """

    DEFAULT_AES = {
        "color": "black",
        "linetype": "solid",
        "alpha": 1,
        "size": 0.5,
    }
    DEFAULT_PARAMS = {
        "stat": "identity",
        "position": "identity",
        "na_rm": False,
        "inherit_aes": False,
    }
    REQUIRED_AES = {"slope", "intercept"}

    # Legend keys are drawn exactly like those of geom_path.
    draw_legend = staticmethod(geom_path.draw_legend)

    def __init__(
        self,
        mapping: Aes | None = None,
        data: DataLike | None = None,
        **kwargs: Any,
    ):
        """
        Set up the abline geom.

        ``slope`` and ``intercept`` may be given as keyword arguments,
        either as scalars or as equal-length sequences; a scalar (or a
        length-1 sequence) is broadcast against the other value. When
        either keyword is given, the pair replaces any ``mapping`` and
        ``data`` (a PlotnineWarning is issued if those were non-empty),
        a missing member of the pair defaults to slope=1 / intercept=0,
        and the layer's legend is turned off. With no mapping and
        neither keyword, the geom draws the line y = x.
        """
        # NOTE(review): presumably normalises the (data, mapping)
        # argument order -- see order_as_data_mapping.
        data, mapping = order_as_data_mapping(data, mapping)
        slope = kwargs.pop("slope", None)
        intercept = kwargs.pop("intercept", None)

        # If nothing is set, it defaults to y=x
        if mapping is None and slope is None and intercept is None:
            slope = 1
            intercept = 0

        if slope is not None or intercept is not None:
            if mapping:
                warn(
                    "The 'intercept' and 'slope' when specified override "
                    "the aes() mapping.",
                    PlotnineWarning,
                )
            if isinstance(data, Sized) and len(data):
                warn(
                    "The 'intercept' and 'slope' when specified override "
                    "the data",
                    PlotnineWarning,
                )

            # Fill in whichever member of the pair was left out.
            if slope is None:
                slope = 1
            if intercept is None:
                intercept = 0

            # Broadcast so that a scalar combines with a sequence in
            # either position (e.g. slope=[1, 2] with intercept=0).
            # Building the frame from a length-1 array next to a longer
            # list would make pandas raise "All arrays must be of the
            # same length". Mismatched lengths still raise ValueError.
            intercepts, slopes = np.broadcast_arrays(
                np.atleast_1d(intercept), np.atleast_1d(slope)
            )
            data = pd.DataFrame({"intercept": intercepts, "slope": slopes})

            mapping = aes(intercept="intercept", slope="slope")
            kwargs["show_legend"] = False

        geom.__init__(self, mapping, data, **kwargs)

    def draw_panel(
        self,
        data: pd.DataFrame,
        panel_params: panel_view,
        coord: Coord,
        ax: Axes,
        **params: Any,
    ):
        """
        Plot all groups

        Each (slope, intercept) row becomes a segment that spans the
        panel's x range; the segments are drawn group by group with
        ``geom_segment.draw_group``. Note that ``x``, ``xend``, ``y``
        and ``yend`` columns are added to the passed-in ``data``.
        """
        # NOTE(review): presumably the panel limits mapped back through
        # the coordinate transform -- confirm in backtransform_range.
        ranges = coord.backtransform_range(panel_params)
        data["x"] = ranges.x[0]
        data["xend"] = ranges.x[1]
        # Evaluate y = slope * x + intercept at both ends of the x range.
        data["y"] = ranges.x[0] * data["slope"] + data["intercept"]
        data["yend"] = ranges.x[1] * data["slope"] + data["intercept"]
        # Identical lines would otherwise be drawn on top of each other.
        data = data.drop_duplicates()

        for _, gdata in data.groupby("group"):
            # Renumber rows 0..n-1 (the old index is kept as a column).
            gdata.reset_index(inplace=True)
            geom_segment.draw_group(gdata, panel_params, coord, ax, **params)