# Source code for statsmodels.graphics.factorplots

"""
Authors:    Josef Perktold, Skipper Seabold, Denis A. Engemann
"""

from statsmodels.compat.python import lrange

import numpy as np

from statsmodels.graphics import utils
from statsmodels.graphics.plottools import rainbow


def interaction_plot(
    x,
    trace,
    response,
    func="mean",
    ax=None,
    plottype="b",
    xlabel=None,
    ylabel=None,
    colors=None,
    markers=None,
    linestyles=None,
    legendloc="best",
    legendtitle=None,
    **kwargs,
):
    """
    Interaction plot for factor level statistics.

    Note. If categorical (string) `x` levels are supplied they are
    internally recoded to integers. This ensures matplotlib compatibility.
    Uses a DataFrame to calculate an `aggregate` statistic for each
    combination of the levels of `trace` and `x`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in the y-axis label if `ylabel` is None.
    func : function
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the (`trace`, `x`) levels.
    ax : axes, optional
        Matplotlib axes instance.
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'.
    xlabel : str, optional
        Label to use for `x`. If not given, the `name` attribute of `x` is
        used when it has one (e.g. a `pandas.Series`), otherwise 'X'.
    ylabel : str, optional
        Name to use for `response` in the y-axis label. The axis label is
        always rendered as '<func name> of <response name>'; if `ylabel` is
        not given, the `name` attribute of `response` is used when it has
        one, otherwise 'response'.
    colors : list, optional
        If given, must have length == number of levels in trace.
        Defaults to `rainbow(n_trace)`.
    markers : list, optional
        If given, must have length == number of levels in trace.
        Defaults to '.' for every trace level.
    linestyles : list, optional
        If given, must have length == number of levels in trace.
        Defaults to '-' for every trace level.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend. If not given, the `name` attribute of `trace`
        is used when it has one, otherwise 'Trace'.
    **kwargs
        These will be passed to the plot command used, either plot or
        scatter. If you want to control the overall plotting options, use
        kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not understood.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """
    from pandas import DataFrame

    fig, ax = utils.create_mpl_ax(ax)

    # `ylabel` (or the response's own name) is only the *name* part of the
    # axis label; the aggregate's name is always prepended.
    response_name = ylabel or getattr(response, "name", "response")
    func_name = getattr(func, "__name__", str(func))
    ylabel = f"{func_name} of {response_name}"
    xlabel = xlabel or getattr(x, "name", "X")
    legendtitle = legendtitle or getattr(trace, "name", "Trace")

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Peek at the first element *by position*. `x[0]` on a pandas object is a
    # label lookup and fails when the index does not contain the label 0.
    first_x = x.iloc[0] if hasattr(x, "iloc") else x[0]
    if isinstance(first_x, str):
        # String levels cannot be plotted directly: map them to 0..k-1 and
        # restore the original labels as tick labels at the end.
        x_levels = np.unique(x).tolist()
        x_values = lrange(len(x_levels))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(["trace", "x"]).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data["trace"].unique())

    linestyles = ["-"] * n_trace if linestyles is None else linestyles
    markers = ["."] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    # Select the drawing call and which per-trace style arguments it gets:
    # lines take no marker, scatter takes no linestyle.
    if plottype in ("both", "b"):
        draw, style_keys = ax.plot, ("marker", "linestyle")
    elif plottype in ("line", "l"):
        draw, style_keys = ax.plot, ("linestyle",)
    elif plottype in ("scatter", "s"):
        draw, style_keys = ax.scatter, ("marker",)
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    for i, (_, group) in enumerate(plot_data.groupby("trace")):
        available = {"marker": markers[i], "linestyle": linestyles[i]}
        draw(
            group["x"],
            group["response"],
            color=colors[i],
            # trace label
            label=str(group["trace"].values[0]),
            **{key: available[key] for key in style_keys},
            **kwargs,
        )

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(0.1)

    # Only set when `x` was recoded from strings: show the original labels.
    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig


def _recode(x, levels):
    """Recode categorical (string) data to an integer factor.

    Parameters
    ----------
    x : array_like
        One-dimensional collection of string labels. May be a
        `pandas.Series`, a numpy array, or any sequence that `numpy.asarray`
        turns into a string or object array.
    levels : dict
        Mapping of labels to integer codings. Its keys must be exactly the
        distinct values occurring in `x`.

    Returns
    -------
    out : numpy.ndarray or pandas.Series
        Integer codes, one per element of `x`. If `x` was a `pandas.Series`
        with a truthy `name`, a Series carrying the same name and index is
        returned; otherwise a plain integer ndarray.

    Raises
    ------
    ValueError
        If `x` is not of string/object dtype, if `levels` is not a dict, or
        if the keys of `levels` do not match the distinct values of `x`.
    """
    from pandas import Series

    name = None
    index = None

    if isinstance(x, Series):
        name = x.name
        index = x.index
        x = x.values

    # Coerce lists, tuples and non-numpy array types to an ndarray so the
    # dtype check and boolean masking below work for any array_like input.
    x = np.asarray(x)

    if x.dtype.type not in [np.str_, np.object_, str]:
        raise ValueError(
            "This is not a categorial factor. Array of str type required."
        )
    if not isinstance(levels, dict):
        raise ValueError("This is not a valid value for levels. Dict required.")
    # array_equal also returns False when the number of distinct values
    # differs, instead of failing on (or broadcasting) mismatched shapes.
    if not np.array_equal(np.unique(x), np.unique(list(levels))):
        raise ValueError("The levels do not match the array values.")

    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    # Only a named Series is re-wrapped; everything else stays an ndarray.
    if name:
        out = Series(out, name=name, index=index)

    return out