Source code for plotnine.positions.position_dodge
from contextlib import suppress
from copy import copy
import numpy as np
import pandas as pd
from ..exceptions import PlotnineError
from ..utils import groupby_apply, match
from .position import position
class position_dodge(position):
    """
    Dodge overlaps and place objects side-by-side

    Parameters
    ----------
    width: float
        Dodging width. Set it when it should differ from the
        width of the individual elements, e.g. to align narrow
        geoms with wider geoms.
    preserve: str in ``['total', 'single']``
        Whether dodging should preserve the total width of all
        elements at a position (``'total'``) or the width of a
        single element (``'single'``).
    """

    REQUIRED_AES = {"x"}

    def __init__(self, width=None, preserve="total"):
        self.params = dict(width=width, preserve=preserve)

    def setup_data(self, data, params):
        """
        Make sure an ``x`` column exists before the parent setup runs

        When ``x`` is missing but both ``xmin`` and ``xmax`` are
        present, ``x`` is set (in place) to their midpoint.
        """
        missing_x = "x" not in data
        if missing_x and "xmin" in data and "xmax" in data:
            data["x"] = (data["xmin"] + data["xmax"]) / 2
        return super().setup_data(data, params)

    def setup_params(self, data):
        """
        Return a copy of the parameters with the entry ``n`` added

        ``n`` is ``None`` when ``preserve`` is ``'total'``. Otherwise
        it is the highest number of rows that share a single ``xmin``
        value (or ``x`` value when there is no ``xmin``) in any panel.

        Raises
        ------
        PlotnineError
            If no width was given and the data has neither an
            ``xmin`` nor an ``xmax`` column.
        """
        has_interval_column = "xmin" in data or "xmax" in data
        if not has_interval_column and self.params["width"] is None:
            raise PlotnineError(
                "Width not defined. Set with `position_dodge(width = ?)`"
            )

        params = copy(self.params)
        if params["preserve"] == "total":
            params["n"] = None
            return params

        def count_per_panel(panel_df):
            # Largest number of rows sharing one xmin value in the
            # panel; x is the fallback when xmin is not available
            try:
                peak = panel_df["xmin"].value_counts().max()
            except KeyError:
                peak = panel_df["x"].value_counts().max()
            return pd.DataFrame({"n": [peak]})

        # The highest per-panel count is used for every panel
        panel_counts = groupby_apply(data, "PANEL", count_per_panel)
        params["n"] = panel_counts["n"].max()
        return params

    @classmethod
    def compute_panel(cls, data, scales, params):
        """
        Dodge the data of one panel by handing it to ``cls.collide``

        ``scales`` is accepted but not used here.
        """
        return cls.collide(data, params=params)

    @staticmethod
    def strategy(data, params):
        """
        Dodge overlapping interval

        Assumes that each set has the same horizontal position.
        The columns ``x``, ``xmin`` and ``xmax`` of ``data`` are
        modified in place and ``data`` is returned.
        """
        width = params["width"]
        # A scalar width makes iter() raise TypeError, which is
        # suppressed so the scalar is used as is. An iterable width
        # is turned into an array and indexed with the data's index.
        with suppress(TypeError):
            iter(width)
            width = np.asarray(width)
            width = width[data.index]

        groups = data["group"].drop_duplicates()

        # Number of slots to dodge into; without a preset value
        # it is the number of distinct groups in this set
        n = params.get("n")
        if n is None:
            n = len(groups)

        # A single slot means there is nothing to dodge
        if n == 1:
            return data

        # Without both interval columns, treat every element as
        # having zero width
        if "xmin" not in data.columns or "xmax" not in data.columns:
            data["xmin"] = data["x"]
            data["xmax"] = data["x"]

        d_width = np.max(data["xmax"] - data["xmin"])

        # Have a new group index from 1 to number of groups.
        # This might be needed if the group numbers in this set don't
        # include all of 1:n
        groupidx = np.asarray(match(data["group"], groups.sort_values())) + 1

        # Find the center for each group, then use that to
        # calculate xmin and xmax
        offset = width * ((groupidx - 0.5) / n - 0.5)
        half_width = (d_width / n) / 2
        data["x"] = data["x"] + offset
        data["xmin"] = data["x"] - half_width  # type: ignore
        data["xmax"] = data["x"] + half_width  # type: ignore
        return data