Source code for plotnine.stats.stat_function
from __future__ import annotations
import typing
import numpy as np
import pandas as pd
from ..doctools import document
from ..exceptions import PlotnineError
from ..mapping.evaluation import after_stat
from ..scales.scale_continuous import scale_continuous
from .stat import stat
if typing.TYPE_CHECKING:
from typing import Callable
from plotnine.typing import FloatArrayLike
@document
class stat_function(stat):
    """
    Superimpose a function onto a plot

    {usage}

    Parameters
    ----------
    {common_parameters}
    fun : function
        Function to evaluate. It is called with the ``x`` values
        as its first argument, followed by whatever is given
        in ``args``.
    n : int, optional (default: 101)
        Number of points at which to evaluate the function.
    xlim : tuple (default: None)
        ``x`` limits for the range. The default depends on
        the ``x`` aesthetic. If there is no ``x`` aesthetic
        then the ``xlim`` must be provided.
    args : tuple, list, dict or object (default: None)
        Extra arguments to pass to ``fun``. A tuple or list is
        unpacked as positional arguments, a dict is unpacked as
        keyword arguments and any other value is passed as a
        single positional argument.
    """

    _aesthetics_doc = """
    {aesthetics_table}

    .. rubric:: Options for computed aesthetics

    ::

        'x'   # x points at which the function is evaluated
        'fx'  # points evaluated at each x

    """

    DEFAULT_PARAMS = {
        "geom": "path",
        "position": "identity",
        "na_rm": False,
        "fun": None,
        "n": 101,
        "args": None,
        "xlim": None,
    }

    # By default the y aesthetic takes the computed "fx" values
    DEFAULT_AES = {"y": after_stat("fx")}
    CREATES = {"fx"}

    def __init__(self, mapping=None, data=None, **kwargs):
        """
        Create the stat

        When no ``data`` is given, a callable is used in its place.
        That callable replaces an empty dataframe with a one-row
        dataframe holding a single ``group``, and returns any
        non-empty dataframe unchanged.
        """
        if data is None:

            def _data_func(data: pd.DataFrame) -> pd.DataFrame:
                if data.empty:
                    data = pd.DataFrame({"group": [1]})
                return data

            data = _data_func

        super().__init__(mapping, data, **kwargs)

    def setup_params(self, data):
        """
        Verify that the ``fun`` parameter is callable

        Returns
        -------
        out : dict
            The parameters of the stat, unchanged.

        Raises
        ------
        PlotnineError
            If ``fun`` is not callable.
        """
        if not callable(self.params["fun"]):
            raise PlotnineError(
                "stat_function requires parameter 'fun' to be "
                "a function or any other callable object"
            )
        return self.params

    @classmethod
    def compute_group(cls, data, scales, **params):
        """
        Evaluate ``fun`` at ``n`` evenly spaced points

        The points span ``xlim`` when it is given, otherwise the
        dimension of the ``x`` scale.

        Returns
        -------
        out : pandas.DataFrame
            Columns ``x`` (the evaluation points) and ``fx``
            (the function values at those points).

        Raises
        ------
        PlotnineError
            If there is neither an ``xlim`` nor an ``x`` scale
            from which to get the range.
        """
        old_fun: Callable[..., FloatArrayLike] = params["fun"]
        n = params["n"]
        args = params["args"]
        xlim = params["xlim"]

        # Explicit None/size checks instead of truthiness, so that
        # a numpy array given as xlim does not raise an
        # "ambiguous truth value" error. An empty xlim still falls
        # back to the scale.
        if xlim is not None and np.size(xlim) > 0:
            range_x = xlim
        elif scales.x is not None:
            range_x = scales.x.dimension((0, 0))
        else:
            raise PlotnineError(
                "stat_function requires the 'xlim' parameter "
                "when there is no 'x' aesthetic"
            )

        # Normalise args into positional and keyword parts so that
        # a single wrapper covers every accepted form
        if isinstance(args, (list, tuple)):
            extra_args, extra_kwargs = tuple(args), {}
        elif isinstance(args, dict):
            extra_args, extra_kwargs = (), args
        elif args is not None:
            extra_args, extra_kwargs = (args,), {}
        else:
            extra_args, extra_kwargs = (), {}

        def fun(x):
            return old_fun(x, *extra_args, **extra_kwargs)

        x = np.linspace(range_x[0], range_x[1], n)

        # continuous scale: undo the scale's transformation before
        # handing the values to the function
        if isinstance(scales.x, scale_continuous):
            x = scales.x.trans.inverse(x)

        # We know these can handle array-likes
        if isinstance(old_fun, (np.ufunc, np.vectorize)):
            fx = fun(x)
        else:
            # Any other callable may only accept scalars
            fx = [fun(val) for val in x]

        new_data = pd.DataFrame({"x": x, "fx": fx})
        return new_data