Source code for plotnine.positions.position_dodge

from contextlib import suppress
from copy import copy

import numpy as np
import pandas as pd

from ..exceptions import PlotnineError
from ..utils import groupby_apply, match
from .position import position


class position_dodge(position):
    """
    Dodge overlaps and place objects side-by-side

    Parameters
    ----------
    width: float
        Dodging width, for when it should differ from the width
        of the individual elements. Useful for aligning narrow
        geoms with wider geoms.
    preserve: str in ``['total', 'single']``
        Whether dodging should preserve the total width of all
        elements at a position, or the width of a single element.
    """

    REQUIRED_AES = {"x"}

    def __init__(self, width=None, preserve="total"):
        self.params = dict(width=width, preserve=preserve)

    def setup_data(self, data, params):
        """
        Make sure ``data`` has an ``x`` column, then defer to the parent

        When ``x`` is missing but both ``xmin`` and ``xmax`` are present,
        ``x`` is set to their midpoint. ``data`` is modified in place.
        """
        if "x" not in data and "xmin" in data and "xmax" in data:
            data["x"] = (data["xmin"] + data["xmax"]) / 2
        return super().setup_data(data, params)

    def setup_params(self, data):
        """
        Return a copy of the parameters with the group count ``n`` added

        Raises
        ------
        PlotnineError
            If no ``width`` was given and ``data`` has neither an
            ``xmin`` nor an ``xmax`` column.
        """
        width_unset = self.params["width"] is None
        if width_unset and "xmin" not in data and "xmax" not in data:
            raise PlotnineError(
                "Width not defined. "
                "Set with `position_dodge(width = ?)`"
            )

        params = copy(self.params)

        # ``n = None`` leaves it to ``strategy`` to count the groups
        # in each set that it is handed.
        if params["preserve"] == "total":
            params["n"] = None
            return params

        def max_xmin_values(gdf):
            # Highest number of rows that share one xmin value
            # (or one x value when there is no xmin column)
            try:
                positions = gdf["xmin"]
            except KeyError:
                positions = gdf["x"]
            return pd.DataFrame({"n": [positions.value_counts().max()]})

        # One count per PANEL; keep the highest overall count
        per_panel = groupby_apply(data, "PANEL", max_xmin_values)
        params["n"] = per_panel["n"].max()
        return params

    @classmethod
    def compute_panel(cls, data, scales, params):
        """
        Dodge the data of one panel

        Delegates to ``cls.collide``; ``scales`` is not used here.
        """
        return cls.collide(data, params=params)

    @staticmethod
    def strategy(data, params):
        """
        Dodge overlapping interval

        Assumes that each set has the same horizontal position.
        ``data`` is modified in place and returned.
        """
        width = params["width"]
        try:
            # A scalar width fails at ``iter`` and is used as is; an
            # iterable width is narrowed to the rows of this set.
            iter(width)
            width = np.asarray(width)
            width = width[data.index]
        except TypeError:
            pass

        groups = data["group"].drop_duplicates()

        n = params.get("n")
        if n is None:
            n = len(groups)
        if n == 1:
            # Nothing to dodge against
            return data

        if "xmin" not in data.columns or "xmax" not in data.columns:
            data["xmin"] = data["x"]
            data["xmax"] = data["x"]

        spans = data["xmax"] - data["xmin"]
        d_width = np.max(spans)

        # Renumber the groups from 1 to the number of groups, since
        # the group numbers in this set may not include all of 1:n
        ordered_groups = groups.sort_values()
        groupidx = np.asarray(match(data["group"], ordered_groups)) + 1

        # Move every row to the center of its group's slot, then
        # rebuild xmin and xmax around that center
        data["x"] = data["x"] + width * ((groupidx - 0.5) / n - 0.5)
        half_slot = (d_width / n) / 2
        data["xmin"] = data["x"] - half_slot
        data["xmax"] = data["x"] + half_slot

        return data