Source code for plotnine.positions.position_dodge2

from __future__ import annotations

import typing
from copy import copy

import numpy as np
import pandas as pd

from ..exceptions import PlotnineError
from ..utils import groupby_apply, pivot_apply
from .position_dodge import position_dodge

if typing.TYPE_CHECKING:
    from plotnine.typing import IntArray


class position_dodge2(position_dodge):
    """
    Dodge overlaps and place objects side-by-side

    This is an enhanced version of
    :class:`~plotnine.positions.position_dodge` that can deal
    with rectangular overlaps that do not share a lower x border.

    Parameters
    ----------
    width: float
        Dodging width, when different to the width of the
        individual elements. This is useful when you want
        to align narrow geoms with wider geoms
    preserve: str in ``['total', 'single']``
        Should dodging preserve the total width of all elements
        at a position, or the width of a single element?
    padding : float
        Padding between elements at the same position.
        Elements are shrunk by this proportion to allow space
        between them (Default: 0.1)
    reverse : bool
        Reverse the default ordering of the groups. This is
        useful if you're rotating both the plot and legend.
        (Default: False)
    """

    REQUIRED_AES = {"x"}

    def __init__(
        self, width=None, preserve="total", padding=0.1, reverse=False
    ):
        self.params = {
            "width": width,
            "preserve": preserve,
            "padding": padding,
            "reverse": reverse,
        }

    def setup_params(self, data):
        """
        Return a copy of the parameters with the divisor ``n`` added

        ``n`` is ``None`` when ``preserve == "total"``; otherwise it
        is the largest number of elements found at a single position
        in any panel.

        Raises
        ------
        PlotnineError
            If ``data`` has neither ``xmin`` nor ``xmax`` and no
            ``width`` was given.
        """
        if (
            ("xmin" not in data)
            and ("xmax" not in data)
            and (self.params["width"] is None)
        ):
            msg = "Width not defined. Set with `position_dodge2(width = ?)`"
            raise PlotnineError(msg)

        params = copy(self.params)

        if params["preserve"] == "total":
            params["n"] = None
        elif "x" in data:

            def max_x_values(gdf):
                n = gdf["x"].value_counts().max()
                return pd.DataFrame({"n": [n]})

            res = groupby_apply(data, "PANEL", max_x_values)
            params["n"] = res["n"].max()
        else:

            def max_overlap_count(gdf):
                # Count the elements in each overlap group and keep
                # the largest count. Taking the max of the group ids
                # themselves would give the number of groups, which
                # is not what the "x" branch above computes.
                n = np.bincount(find_x_overlaps(gdf)).max()
                return pd.DataFrame({"n": [n]})

            # interval geoms
            res = groupby_apply(data, "PANEL", max_overlap_count)
            params["n"] = res["n"].max()
        return params

    @classmethod
    def compute_panel(cls, data, scales, params):
        """
        Dodge the elements of one panel by delegating to ``collide2``
        """
        return cls.collide2(data, params=params)

    @staticmethod
    def strategy(data, params):
        """
        Place overlapping elements side-by-side

        Reads ``params["padding"]`` and ``params["n"]``. Helper
        columns are added to ``data`` while computing; the returned
        dataframe has ``x``, ``xmin`` and ``xmax`` updated and the
        helper columns dropped.
        """
        padding = params["padding"]
        n = params["n"]

        if not all(col in data.columns for col in ["xmin", "xmax"]):
            data["xmin"] = data["x"]
            data["xmax"] = data["x"]

        # Groups of boxes that share the same position
        data["xid"] = find_x_overlaps(data)

        # Find newx using xid, i.e. the center of each group of
        # overlapping elements. for boxes, bars, etc. this should
        # be the same as original x, but for arbitrary rects it
        # may not be
        res1 = pivot_apply(data, "xmin", "xid", np.min)
        res2 = pivot_apply(data, "xmax", "xid", np.max)
        data["newx"] = (res1 + res2)[data["xid"].to_numpy()].to_numpy() / 2

        if n is None:
            # If n is None, preserve total widths of elements at
            # each position by dividing widths by the number of
            # elements at that position.
            # bincount is indexed by the xid value itself, so every
            # row picks up the size of its own group. (value_counts
            # orders by descending count, so its positions cannot be
            # looked up with xid - 1.)
            xid = data["xid"].to_numpy()
            n = np.bincount(xid)[xid]

        data["new_width"] = (data["xmax"] - data["xmin"]) / n

        # Find the total width of each group of elements
        def sum_new_width(gdf):
            return pd.DataFrame(
                {
                    "size": [gdf["new_width"].sum()],
                    "newx": gdf["newx"].iloc[0],
                }
            )

        group_sizes = groupby_apply(data, "newx", sum_new_width)

        # Starting xmin for each group of elements
        starts = group_sizes["newx"] - (group_sizes["size"] / 2)

        # Set the elements in place
        # NOTE(review): the i-th start is applied to xid == i, which
        # presumably relies on groupby_apply returning the newx groups
        # in the same order as the xid values -- confirm.
        for i, start in enumerate(starts, start=1):
            bool_idx = data["xid"] == i
            divisions = np.cumsum(
                np.hstack([start, data.loc[bool_idx, "new_width"]])
            )
            data.loc[bool_idx, "xmin"] = divisions[:-1]
            data.loc[bool_idx, "xmax"] = divisions[1:]

        # x values get moved to between xmin and xmax
        data["x"] = (data["xmin"] + data["xmax"]) / 2

        # Shrink elements to add space between them
        if data["xid"].duplicated().any():
            pad_width = data["new_width"] * (1 - padding)
            data["xmin"] = data["x"] - pad_width / 2
            data["xmax"] = data["x"] + pad_width / 2

        data = data.drop(columns=["xid", "newx", "new_width"], errors="ignore")
        return data


def find_x_overlaps(df: pd.DataFrame) -> IntArray:
    """
    Find overlapping regions along the x axis

    Rows are scanned in order. A row starts a new region when its
    ``xmin`` is greater than or equal to the ``xmax`` of the row
    just before it; otherwise it joins that previous row's region.

    Parameters
    ----------
    df : pd.DataFrame
        Data with ``xmin`` and ``xmax`` columns.

    Returns
    -------
    out : IntArray
        Region id (starting at 1) for every row of ``df``. It is
        empty when ``df`` has no rows.
    """
    n = len(df)
    if n == 0:
        # Nothing to label; avoids indexing into an empty array
        return np.zeros(0, dtype=int)

    xmin = df["xmin"].to_numpy()
    xmax = df["xmax"].to_numpy()

    # The first row always opens region 1; every later row opens a
    # new region only if it begins at or after the previous row ends
    starts_region = np.ones(n, dtype=bool)
    starts_region[1:] = xmin[1:] >= xmax[:-1]

    # Running count of region starts is the region id of each row
    return np.cumsum(starts_region, dtype=int)