"""
Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
"""
from statsmodels.compat.python import lrange
import numpy as np
from statsmodels.graphics import utils
from statsmodels.graphics.plottools import rainbow
[docs]
def interaction_plot(
    x,
    trace,
    response,
    func="mean",
    ax=None,
    plottype="b",
    xlabel=None,
    ylabel=None,
    colors=None,
    markers=None,
    linestyles=None,
    legendloc="best",
    legendtitle=None,
    **kwargs,
):
    """
    Interaction plot for factor level statistics.

    Note. If categorical (string) factor levels are supplied for `x` they
    will be internally recoded to integers. This ensures matplotlib
    compatibility; the original labels are used as x tick labels. Uses
    a DataFrame to calculate an `aggregate` statistic for each level of the
    factor or group given by `trace`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : function
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by each (trace, x) level combination.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a `pandas.Series` it
        will use the series names.
    ylabel : str, optional
        Name used for `response` in the y-axis label, which is rendered as
        '<func name> of <ylabel>'. Defaults to the name of `response` if it
        is a `pandas.Series`, otherwise 'response'.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` does not have one entry per
        trace level, or if `plottype` is not recognised.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """
    from pandas import DataFrame

    fig, ax = utils.create_mpl_ax(ax)

    # Labels fall back to pandas Series names when not given explicitly.
    response_name = ylabel or getattr(response, "name", "response")
    func_name = getattr(func, "__name__", str(func))
    ylabel = f"{func_name} of {response_name}"
    xlabel = xlabel or getattr(x, "name", "X")
    legendtitle = legendtitle or getattr(trace, "name", "Trace")
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    # String-valued x levels are recoded to 0..k-1 so matplotlib can place
    # them numerically; the labels are restored as tick labels below.
    # Pandas objects are probed positionally: ``x[0]`` on a Series is a
    # *label* lookup and raises KeyError when the index does not contain 0.
    first = x.iloc[0] if hasattr(x, "iloc") else x[0]
    x_values = x_levels = None
    if isinstance(first, str):
        x_levels = np.unique(x).tolist()
        x_values = lrange(len(x_levels))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(["trace", "x"]).aggregate(func).reset_index()

    # One style entry is required per trace level.
    n_trace = len(plot_data["trace"].unique())
    linestyles = ["-"] * n_trace if linestyles is None else linestyles
    markers = ["."] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors
    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    # Select the drawing routine and which per-trace styles it receives:
    # 'both' -> line + markers, 'line' -> line only, 'scatter' -> markers.
    if plottype in ("both", "b"):
        draw, use_marker, use_linestyle = ax.plot, True, True
    elif plottype in ("line", "l"):
        draw, use_marker, use_linestyle = ax.plot, False, True
    elif plottype in ("scatter", "s"):
        draw, use_marker, use_linestyle = ax.scatter, True, False
    else:
        # Tuple-wrap so a tuple plottype is formatted, not unpacked.
        raise ValueError("Plot type %s not understood" % (plottype,))

    for i, (_, group) in enumerate(plot_data.groupby("trace")):
        style = {"color": colors[i]}
        if use_marker:
            style["marker"] = markers[i]
        if use_linestyle:
            style["linestyle"] = linestyles[i]
        # Separate ** unpackings keep the original behaviour of raising
        # TypeError if kwargs duplicates a style/label keyword.
        draw(
            group["x"],
            group["response"],
            label=str(group["trace"].values[0]),
            **style,
            **kwargs,
        )

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(0.1)

    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)

    return fig
def _recode(x, levels):
"""Recode categorial data to int factor.
Parameters
----------
x : array_like
array like object supporting with numpy array methods of categorially
coded data.
levels : dict
mapping of labels to integer-codings
Returns
-------
out : instance numpy.ndarray
"""
from pandas import Series
name = None
index = None
if isinstance(x, Series):
name = x.name
index = x.index
x = x.values
if x.dtype.type not in [np.str_, np.object_, str]:
raise ValueError(
"This is not a categorial factor. Array of str type required."
)
elif not isinstance(levels, dict):
raise ValueError("This is not a valid value for levels. Dict required.")
elif not (np.unique(x) == np.unique(list(levels.keys()))).all():
raise ValueError("The levels do not match the array values.")
else:
out = np.empty(x.shape[0], dtype=int)
for level, coding in levels.items():
out[x == level] = coding
if name:
out = Series(out, name=name, index=index)
return out