Source code for plotnine.guides.guide_legend

from __future__ import annotations

import hashlib
import types
import typing
from contextlib import suppress
from itertools import islice
from warnings import warn

import numpy as np
import pandas as pd

from ..exceptions import PlotnineError, PlotnineWarning
from ..geoms import geom_text
from ..mapping.aes import rename_aesthetics
from ..utils import SIZE_FACTOR, remove_missing
from .guide import guide

if typing.TYPE_CHECKING:
    from typing import Optional

    from plotnine.typing import TupleInt2


# See guides.py for terminology


[docs]class guide_legend(guide): """ Legend guide Parameters ---------- nrow : int Number of rows of legends. ncol : int Number of columns of legends. byrow : bool Whether to fill the legend row-wise or column-wise. keywidth : float Width of the legend key. keyheight : float Height of the legend key. kwargs : dict Parameters passed on to :class:`.guide` """ # general nrow: int = -1 ncol: int = -1 byrow = False # key keywidth: Optional[int] = None keyheight: Optional[int] = None # parameter available_aes = {"any"} def train(self, scale, aesthetic=None): """ Create the key for the guide The key is a dataframe with two columns: - scale name : values - label : labels for each value scale name is one of the aesthetics ['x', 'y', 'color', 'fill', 'size', 'shape', 'alpha', 'stroke'] Returns this guide if trainning is successful and None if it fails """ if aesthetic is None: aesthetic = scale.aesthetics[0] breaks = scale.get_bounded_breaks() if not breaks: return None key = pd.DataFrame( {aesthetic: scale.map(breaks), "label": scale.get_labels(breaks)} ) if len(key) == 0: return None self.key = key # create a hash of the important information in the guide labels = " ".join(str(x) for x in self.key["label"]) info = "\n".join( [self.title, labels, str(self.direction), self.__class__.__name__] ) self.hash = hashlib.md5(info.encode("utf-8")).hexdigest() return self def merge(self, other): """ Merge overlapped guides For example:: from ggplot import * gg = ggplot(aes(x='cut', fill='cut', color='cut'), data=diamonds) gg + stat_bin() This would create similar guides for fill and color where only a single guide would do """ self.key = self.key.merge(other.key) duplicated = set(self.override_aes) & set(other.override_aes) if duplicated: warn("Duplicated override_aes is ignored.", PlotnineWarning) self.override_aes.update(other.override_aes) for ae in duplicated: del self.override_aes[ae] return self def create_geoms(self, plot): """ Make information needed to draw a legend for each of the layers. For each layer, that information is a dictionary with the geom to draw the guide together with the data and the parameters that will be used in the call to geom. """ # A layer either contributes to the guide, or it does not. The # guide entries may be ploted in the layers self.glayers = [] for l in plot.layers: exclude = set() if isinstance(l.show_legend, dict): l.show_legend = rename_aesthetics(l.show_legend) exclude = {ae for ae, val in l.show_legend.items() if not val} elif l.show_legend not in (None, True): continue matched = self.legend_aesthetics(l, plot) # This layer does not contribute to the legend if not set(matched) - exclude: continue data = self.key[matched].copy() # Modify aesthetics try: data = l.use_defaults(data) except PlotnineError: warn( "Failed to apply `after_scale` modifications " "to the legend.", PlotnineWarning, ) data = l.use_defaults(data, aes_modifiers={}) # override.aes in guide_legend manually changes the geom for ae in set(self.override_aes) & set(data.columns): data[ae] = self.override_aes[ae] data = remove_missing( data, l.geom.params["na_rm"], list(l.geom.REQUIRED_AES | l.geom.NON_MISSING_AES), f"{l.geom.__class__.__name__} legend", ) self.glayers.append( types.SimpleNamespace(geom=l.geom, data=data, layer=l) ) if not self.glayers: return None return self def _calculate_rows_and_cols(self) -> TupleInt2: nrow, ncol = -1, -1 nbreak = len(self.key) if hasattr(self, "nrow"): nrow = self.nrow if hasattr(self, "ncol"): ncol = self.ncol if nrow != -1 and ncol != -1: if nrow * ncol < nbreak: raise PlotnineError( "nrow x ncol need to be larger " "than the number of breaks" ) return nrow, ncol if nrow == -1 and ncol == -1: if self.direction == "horizontal": nrow = int(np.ceil(nbreak / 5)) else: ncol = int(np.ceil(nbreak / 20)) if nrow == -1: nrow = int(np.ceil(nbreak / ncol)) elif ncol == -1: ncol = int(np.ceil(nbreak / nrow)) return nrow, ncol def _set_defaults(self, theme): guide._set_defaults(self, theme) _property = theme.themeables.property self.nrow, self.ncol = self._calculate_rows_and_cols() nbreak = len(self.key) # key width and key height for each legend entry # # Take a peak into data['size'] to make sure the # legend dimensions are big enough """ >>> gg = ggplot(diamonds, aes(x='cut', y='clarity')) >>> gg = gg + stat_sum(aes(group='cut')) >>> gg + scale_size(range=(3, 25)) Note the different height sizes for the entries """ # FIXME: This should be in the geom instead of having # special case conditional branches def determine_side_length(initial_size): default_pad = initial_size * 0.5 # default_pad = 0 size = np.ones(nbreak) * initial_size for i in range(nbreak): for gl in self.glayers: _size = 0 pad = default_pad # Full size of object to appear in the # legend key with suppress(IndexError): if "size" in gl.data: _size = gl.data["size"].iloc[i] * SIZE_FACTOR if "stroke" in gl.data: _size += ( 2 * gl.data["stroke"].iloc[i] * SIZE_FACTOR ) # special case, color does not apply to # border/linewidth if isinstance(gl.geom, geom_text): pad = 0 if _size < initial_size: continue try: # color(edgecolor) affects size(linewidth) # When the edge is not visible, we should # not expand the size of the keys if gl.data["color"].iloc[i] is not None: size[i] = np.max([_size + pad, size[i]]) except KeyError: break return size # keysize if self.keywidth is None: width = determine_side_length(_property("legend_key_width")) if self.direction == "vertical": width[:] = width.max() self._keywidth = width else: self._keywidth = [self.keywidth] * nbreak if self.keyheight is None: height = determine_side_length(_property("legend_key_height")) if self.direction == "horizontal": height[:] = height.max() self._keyheight = height else: self._keyheight = [self.keyheight] * nbreak def draw(self): """ Draw guide Returns ------- out : matplotlib.offsetbox.Offsetbox A drawing of this legend """ from matplotlib.offsetbox import HPacker, TextArea, VPacker from .._mpl.offsetbox import ColoredDrawingArea obverse = slice(0, None) reverse = slice(None, None, -1) nbreak = len(self.key) _targets = self.theme._targets # When there is more than one guide, we keep # record of all of them using lists if "legend_title" not in _targets: _targets["legend_title"] = [] if "legend_text_legend" not in _targets: _targets["legend_key"] = [] _targets["legend_text_legend"] = [] # title title_box = TextArea(self.title, textprops={"color": "black"}) _targets["legend_title"].append(title_box) # labels labels = [] for item in self.key["label"]: if isinstance(item, float) and float.is_integer(item): item = int(item) # 1.0 to 1 va = "center" if self.label_position == "top" else "baseline" ta = TextArea(item, textprops={"color": "black", "va": va}) labels.append(ta) _targets["legend_text_legend"].extend(labels) # Drawings drawings = [] for i in range(nbreak): da = ColoredDrawingArea( self._keywidth[i], self._keyheight[i], 0, 0, color="white" ) # overlay geoms for gl in self.glayers: with suppress(IndexError): data = gl.data.iloc[i] da = gl.geom.draw_legend(data, da, gl.layer) drawings.append(da) _targets["legend_key"].append(drawings) # Match Drawings with labels to create the entries lookup = { "right": (HPacker, reverse), "left": (HPacker, obverse), "bottom": (VPacker, reverse), "top": (VPacker, obverse), } packer, slc = lookup[self.label_position] entries = [] for d, l in zip(drawings, labels): e = packer( children=[l, d][slc], sep=self._label_margin, align="center", pad=0, ) entries.append(e) # Put the entries together in rows or columns # A chunk is either a row or a column of entries # for a single legend if self.byrow: chunk_size, packers = self.ncol, [HPacker, VPacker] sep1 = self._legend_entry_spacing_x sep2 = self._legend_entry_spacing_y else: chunk_size, packers = self.nrow, [VPacker, HPacker] sep1 = self._legend_entry_spacing_y sep2 = self._legend_entry_spacing_x if self.reverse: entries = entries[::-1] chunks = [] for i in range(len(entries)): start = i * chunk_size stop = start + chunk_size s = islice(entries, start, stop) chunks.append(list(s)) if stop >= len(entries): break chunk_boxes = [] for chunk in chunks: d1 = packers[0](children=chunk, align="left", sep=sep1, pad=0) chunk_boxes.append(d1) # Put all the entries (row & columns) together entries_box = packers[1]( children=chunk_boxes, align="baseline", sep=sep2, pad=0 ) # Put the title and entries together packer, slc = lookup[self.title_position] children = [title_box, entries_box][slc] box = packer( children=children, sep=self._title_margin, align=self._title_align, pad=self._legend_margin, ) return box