from __future__ import annotations
import typing
from types import SimpleNamespace as NS
from warnings import warn
from ..exceptions import PlotnineWarning
from ..iapi import panel_ranges, panel_view
from ..positions.position import transform_position
from .coord import coord, dist_euclidean
if typing.TYPE_CHECKING:
from typing import Optional
import pandas as pd
from plotnine.iapi import scale_view
from plotnine.typing import (
FloatArray,
FloatSeries,
Scale,
TFloatArrayLike,
Trans,
TupleFloat2,
)
class coord_trans(coord):
    """
    Transformed cartesian coordinate system

    Parameters
    ----------
    x : str | trans
        Name of transform or `trans` class to
        transform the x axis
    y : str | trans
        Name of transform or `trans` class to
        transform the y axis
    xlim : None | (float, float)
        Limits for x axis. If None, then they are
        automatically computed.
    ylim : None | (float, float)
        Limits for y axis. If None, then they are
        automatically computed.
    expand : bool
        If `True`, expand the coordinate axes by
        some factor. If `False`, use the limits
        from the data.
    """

    trans_x: Trans
    trans_y: Trans

    def __init__(
        self,
        x: str | Trans = "identity",
        y: str | Trans = "identity",
        xlim: Optional[TupleFloat2] = None,
        ylim: Optional[TupleFloat2] = None,
        expand: bool = True,
    ):
        # Imported here rather than at module level, as in the
        # rest of this module's use of mizani.
        from mizani.transforms import gettrans

        # gettrans resolves a name or trans class into the
        # transform object stored on the coordinate system.
        self.trans_x = gettrans(x)
        self.trans_y = gettrans(y)
        self.limits = NS(x=xlim, y=ylim)
        self.expand = expand

    def transform(
        self, data: pd.DataFrame, panel_params: panel_view, munch: bool = False
    ) -> pd.DataFrame:
        """
        Transform the position columns of `data` into coordinate space

        Parameters
        ----------
        data :
            Data whose position aesthetics are transformed.
        panel_params :
            Panel view supplying the x and y ranges.
        munch :
            If `True` and the coordinate system is not linear,
            `data` is first passed through `self.munch`.

        Returns
        -------
        :
            Result of applying the x and y transforms, followed by
            `squish_infinite`, to the position columns.

        Warns
        -----
        PlotnineWarning
            If transforming a column produces one or more NaN values.
        """
        from mizani.bounds import squish_infinite

        # is_linear and munch are inherited; they are not defined
        # in this class.
        if not self.is_linear and munch:
            data = self.munch(data, panel_params)

        def warn_if_nan(result: FloatSeries, aesthetic: str) -> None:
            # Single home for the NaN check shared by both axes.
            # .isna().any() is the vectorised form of any(.isna())
            if result.isna().any():
                warn(
                    f"Coordinate transform of {aesthetic} aesthetic "
                    "created one or more NaN values.",
                    PlotnineWarning,
                )

        # The transform and the range are looked up when the
        # closure is called, not when it is created.
        def trans_x(col: FloatSeries) -> FloatSeries:
            result = transform_value(self.trans_x, col, panel_params.x.range)
            warn_if_nan(result, "x")
            return result

        def trans_y(col: FloatSeries) -> FloatSeries:
            result = transform_value(self.trans_y, col, panel_params.y.range)
            warn_if_nan(result, "y")
            return result

        data = transform_position(data, trans_x, trans_y)
        # Second pass deals with infinite values that the
        # transforms may have produced.
        return transform_position(data, squish_infinite, squish_infinite)

    def backtransform_range(self, panel_params: panel_view) -> panel_ranges:
        """
        Return the panel ranges mapped through the inverse transforms
        """
        return panel_ranges(
            x=self.trans_x.inverse(panel_params.x.range),
            y=self.trans_y.inverse(panel_params.y.range),
        )

    def setup_panel_params(self, scale_x: Scale, scale_y: Scale) -> panel_view:
        """
        Compute the range and break information for the panel
        """

        def get_scale_view(
            scale: Scale, coord_limits: Optional[TupleFloat2], trans: Trans
        ) -> scale_view:
            # coord_limits is None when the user did not set
            # xlim/ylim; only real limits are transformed.
            if coord_limits:
                coord_limits = trans.transform(coord_limits)

            expansion = scale.default_expansion(expand=self.expand)
            ranges = scale.expand_limits(
                scale.limits, expansion, coord_limits, trans
            )
            sv = scale.view(limits=coord_limits, range=ranges.range)
            # Store the range as an ascending (low, high) pair
            sv.range = tuple(sorted(ranges.range_coord))
            sv.breaks = transform_value(
                trans,
                # TODO: fix typecheck
                sv.breaks,  # pyright: ignore
                sv.range,
            )
            sv.minor_breaks = transform_value(trans, sv.minor_breaks, sv.range)
            return sv

        return panel_view(
            x=get_scale_view(scale_x, self.limits.x, self.trans_x),
            y=get_scale_view(scale_y, self.limits.y, self.trans_y),
        )

    def distance(
        self,
        x: FloatSeries,
        y: FloatSeries,
        panel_params: panel_view,
    ) -> FloatArray:
        """
        Return distances between the transformed points, normalised

        The values from `dist_euclidean` on the transformed `x` and
        `y` are divided by the first value `dist_euclidean` returns
        for the panel's x and y ranges.
        """
        max_dist = dist_euclidean(
            panel_params.x.range, panel_params.y.range
        )[0]
        xt = self.trans_x.transform(x)
        yt = self.trans_y.transform(y)
        return dist_euclidean(xt, yt) / max_dist
def transform_value(
    trans: Trans, value: TFloatArrayLike, range: TupleFloat2
) -> TFloatArrayLike:
    """
    Apply the forward transformation of `trans` to `value`

    Parameters
    ----------
    trans :
        Object whose `transform` method is applied.
    value :
        Value(s) to transform.
    range :
        Part of the call signature but not used by this function.

    Returns
    -------
    :
        Whatever `trans.transform(value)` returns.
    """
    transformed = trans.transform(value)
    return transformed