{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear regression diagnostics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In real life, the relationship between the response and the predictor variables is seldom linear. Here, we use the outputs of statsmodels to visualise and identify potential problems that can arise from fitting a linear regression model to a non-linear relationship. Primarily, the aim is to reproduce the visualisations discussed in the Potential Problems section (Chapter 3.3.3) of *An Introduction to Statistical Learning* (ISLR) by James et al., Springer." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:05.503608Z", "iopub.status.busy": "2023-01-26T15:15:05.502527Z", "iopub.status.idle": "2023-01-26T15:15:06.092672Z", "shell.execute_reply": "2023-01-26T15:15:06.092078Z" } }, "outputs": [], "source": [ "import statsmodels\n", "import statsmodels.formula.api as smf\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Simple multiple linear regression\n", "\n", "First, let us load the Advertising data from Chapter 2 of ISLR and fit a linear model to it." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:06.096844Z", "iopub.status.busy": "2023-01-26T15:15:06.096322Z", "iopub.status.idle": "2023-01-26T15:15:06.268547Z", "shell.execute_reply": "2023-01-26T15:15:06.267940Z" } }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>TV</th><th>Radio</th><th>Newspaper</th><th>Sales</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>230.1</td><td>37.8</td><td>69.2</td><td>22.1</td></tr>\n",
"    <tr><th>1</th><td>44.5</td><td>39.3</td><td>45.1</td><td>10.4</td></tr>\n",
"    <tr><th>2</th><td>17.2</td><td>45.9</td><td>69.3</td><td>9.3</td></tr>\n",
"    <tr><th>3</th><td>151.5</td><td>41.3</td><td>58.5</td><td>18.5</td></tr>\n",
"    <tr><th>4</th><td>180.8</td><td>10.8</td><td>58.4</td><td>12.9</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [
"      TV  Radio  Newspaper  Sales\n",
"0  230.1   37.8       69.2   22.1\n",
"1   44.5   39.3       45.1   10.4\n",
"2   17.2   45.9       69.3    9.3\n",
"3  151.5   41.3       58.5   18.5\n",
"4  180.8   10.8       58.4   12.9" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Load data\n",
"data_url = \"https://raw.githubusercontent.com/nguyen-toan/ISLR/07fd968ea484b5f6febc7b392a28eb64329a4945/dataset/Advertising.csv\"\n",
"df = pd.read_csv(data_url).drop('Unnamed: 0', axis=1)\n",
"df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:06.273882Z", "iopub.status.busy": "2023-01-26T15:15:06.271010Z", "iopub.status.idle": "2023-01-26T15:15:06.301822Z", "shell.execute_reply": "2023-01-26T15:15:06.301192Z" } }, "outputs": [ { "data": { "text/html": [ "
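<table class=\"simpletable\">\n",
"<caption>OLS Regression Results</caption>\n",
"<tr><th>Dep. Variable:</th><td>Sales</td><th>R-squared:</th><td>0.897</td></tr>\n",
"<tr><th>Model:</th><td>OLS</td><th>Adj. R-squared:</th><td>0.896</td></tr>\n",
"<tr><th>Method:</th><td>Least Squares</td><th>F-statistic:</th><td>570.3</td></tr>\n",
"<tr><th>Date:</th><td>Thu, 26 Jan 2023</td><th>Prob (F-statistic):</th><td>1.58e-96</td></tr>\n",
"<tr><th>Time:</th><td>15:15:06</td><th>Log-Likelihood:</th><td>-386.18</td></tr>\n",
"<tr><th>No. Observations:</th><td>200</td><th>AIC:</th><td>780.4</td></tr>\n",
"<tr><th>Df Residuals:</th><td>196</td><th>BIC:</th><td>793.6</td></tr>\n",
"<tr><th>Df Model:</th><td>3</td><th></th><td></td></tr>\n",
"<tr><th>Covariance Type:</th><td>nonrobust</td><th></th><td></td></tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr><td></td><th>coef</th><th>std err</th><th>t</th><th>P&gt;|t|</th><th>[0.025</th><th>0.975]</th></tr>\n",
"<tr><th>Intercept</th><td>2.9389</td><td>0.312</td><td>9.422</td><td>0.000</td><td>2.324</td><td>3.554</td></tr>\n",
"<tr><th>TV</th><td>0.0458</td><td>0.001</td><td>32.809</td><td>0.000</td><td>0.043</td><td>0.049</td></tr>\n",
"<tr><th>Radio</th><td>0.1885</td><td>0.009</td><td>21.893</td><td>0.000</td><td>0.172</td><td>0.206</td></tr>\n",
"<tr><th>Newspaper</th><td>-0.0010</td><td>0.006</td><td>-0.177</td><td>0.860</td><td>-0.013</td><td>0.011</td></tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr><th>Omnibus:</th><td>60.414</td><th>Durbin-Watson:</th><td>2.084</td></tr>\n",
"<tr><th>Prob(Omnibus):</th><td>0.000</td><th>Jarque-Bera (JB):</th><td>151.241</td></tr>\n",
"<tr><th>Skew:</th><td>-1.327</td><th>Prob(JB):</th><td>1.44e-33</td></tr>\n",
"<tr><th>Kurtosis:</th><td>6.332</td><th>Cond. No.</th><td>454.</td></tr>\n",
"</table><br/><br/>Notes:<br/>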
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified." ], "text/latex": [ "\\begin{center}\n", "\\begin{tabular}{lclc}\n", "\\toprule\n", "\\textbf{Dep. Variable:} & Sales & \\textbf{ R-squared: } & 0.897 \\\\\n", "\\textbf{Model:} & OLS & \\textbf{ Adj. R-squared: } & 0.896 \\\\\n", "\\textbf{Method:} & Least Squares & \\textbf{ F-statistic: } & 570.3 \\\\\n", "\\textbf{Date:} & Thu, 26 Jan 2023 & \\textbf{ Prob (F-statistic):} & 1.58e-96 \\\\\n", "\\textbf{Time:} & 15:15:06 & \\textbf{ Log-Likelihood: } & -386.18 \\\\\n", "\\textbf{No. Observations:} & 200 & \\textbf{ AIC: } & 780.4 \\\\\n", "\\textbf{Df Residuals:} & 196 & \\textbf{ BIC: } & 793.6 \\\\\n", "\\textbf{Df Model:} & 3 & \\textbf{ } & \\\\\n", "\\textbf{Covariance Type:} & nonrobust & \\textbf{ } & \\\\\n", "\\bottomrule\n", "\\end{tabular}\n", "\\begin{tabular}{lcccccc}\n", " & \\textbf{coef} & \\textbf{std err} & \\textbf{t} & \\textbf{P$> |$t$|$} & \\textbf{[0.025} & \\textbf{0.975]} \\\\\n", "\\midrule\n", "\\textbf{Intercept} & 2.9389 & 0.312 & 9.422 & 0.000 & 2.324 & 3.554 \\\\\n", "\\textbf{TV} & 0.0458 & 0.001 & 32.809 & 0.000 & 0.043 & 0.049 \\\\\n", "\\textbf{Radio} & 0.1885 & 0.009 & 21.893 & 0.000 & 0.172 & 0.206 \\\\\n", "\\textbf{Newspaper} & -0.0010 & 0.006 & -0.177 & 0.860 & -0.013 & 0.011 \\\\\n", "\\bottomrule\n", "\\end{tabular}\n", "\\begin{tabular}{lclc}\n", "\\textbf{Omnibus:} & 60.414 & \\textbf{ Durbin-Watson: } & 2.084 \\\\\n", "\\textbf{Prob(Omnibus):} & 0.000 & \\textbf{ Jarque-Bera (JB): } & 151.241 \\\\\n", "\\textbf{Skew:} & -1.327 & \\textbf{ Prob(JB): } & 1.44e-33 \\\\\n", "\\textbf{Kurtosis:} & 6.332 & \\textbf{ Cond. No. } & 454. \\\\\n", "\\bottomrule\n", "\\end{tabular}\n", "%\\caption{OLS Regression Results}\n", "\\end{center}\n", "\n", "Notes: \\newline\n", " [1] Standard Errors assume that the covariance matrix of the errors is correctly specified." 
], "text/plain": [
"<class 'statsmodels.iolib.summary.Summary'>\n",
"\"\"\"\n",
"                            OLS Regression Results                            \n",
"==============================================================================\n",
"Dep. Variable:                  Sales   R-squared:                       0.897\n",
"Model:                            OLS   Adj. R-squared:                  0.896\n",
"Method:                 Least Squares   F-statistic:                     570.3\n",
"Date:                Thu, 26 Jan 2023   Prob (F-statistic):           1.58e-96\n",
"Time:                        15:15:06   Log-Likelihood:                -386.18\n",
"No. Observations:                 200   AIC:                             780.4\n",
"Df Residuals:                     196   BIC:                             793.6\n",
"Df Model:                           3                                         \n",
"Covariance Type:            nonrobust                                         \n",
"==============================================================================\n",
"                 coef    std err          t      P>|t|      [0.025      0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept      2.9389      0.312      9.422      0.000       2.324       3.554\n",
"TV             0.0458      0.001     32.809      0.000       0.043       0.049\n",
"Radio          0.1885      0.009     21.893      0.000       0.172       0.206\n",
"Newspaper     -0.0010      0.006     -0.177      0.860      -0.013       0.011\n",
"==============================================================================\n",
"Omnibus:                       60.414   Durbin-Watson:                   2.084\n",
"Prob(Omnibus):                  0.000   Jarque-Bera (JB):              151.241\n",
"Skew:                          -1.327   Prob(JB):                     1.44e-33\n",
"Kurtosis:                       6.332   Cond. No.                         454.\n",
"==============================================================================\n",
"\n",
"Notes:\n",
"[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n",
"\"\"\"" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Fitting linear model\n",
"res = smf.ols(formula=\"Sales ~ TV + Radio + Newspaper\", data=df).fit()\n",
"res.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"#### Diagnostic Figures/Table\n",
"\n",
"Below, we first present the base code that we will later use to generate the following diagnostic plots:\n",
"\n",
"    a. residual\n",
"    b. qq\n",
"    c. scale location\n",
"    d. leverage\n",
"\n",
"and a table\n",
"\n",
"    a. 
vif" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:06.304826Z", "iopub.status.busy": "2023-01-26T15:15:06.304607Z", "iopub.status.idle": "2023-01-26T15:15:06.764358Z", "shell.execute_reply": "2023-01-26T15:15:06.763675Z" } }, "outputs": [], "source": [
"# base code\n",
"import numpy as np\n",
"import seaborn as sns\n",
"from statsmodels.tools.tools import maybe_unwrap_results\n",
"from statsmodels.graphics.gofplots import ProbPlot\n",
"from statsmodels.stats.outliers_influence import variance_inflation_factor\n",
"import matplotlib.pyplot as plt\n",
"from typing import Type\n",
"\n",
"# Refer to plt.style.available; from Matplotlib 3.6 on, the seaborn styles\n",
"# are named 'seaborn-v0_8-*' (e.g. 'seaborn-v0_8-talk').\n",
"style_talk = 'seaborn-talk'\n",
"\n",
"class Linear_Reg_Diagnostic():\n",
"    \"\"\"\n",
"    Diagnostic plots to identify potential problems in a linear regression fit.\n",
"    Mainly,\n",
"        a. non-linearity of data\n",
"        b. correlation of error terms\n",
"        c. non-constant variance\n",
"        d. outliers\n",
"        e. high-leverage points\n",
"        f. collinearity\n",
"\n",
"    Author:\n",
"        Prajwal Kafle (p33ajkafle@gmail.com, where 3 = r)\n",
"        Does not come with any sort of warranty.\n",
"        Please test the code on your end before using.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 results: Type[statsmodels.regression.linear_model.RegressionResultsWrapper]) -> None:\n",
"        \"\"\"\n",
"        For a linear regression model, generates the following diagnostic plots:\n",
"\n",
"        a. residual\n",
"        b. qq\n",
"        c. scale location and\n",
"        d. leverage\n",
"\n",
"        and a table\n",
"\n",
"        e. 
vif\n",
"\n",
"        Args:\n",
"            results (Type[statsmodels.regression.linear_model.RegressionResultsWrapper]):\n",
"                must be an instance of statsmodels.regression.linear_model.RegressionResultsWrapper\n",
"\n",
"        Raises:\n",
"            TypeError: if results is not an instance of the above class\n",
"\n",
"        Example:\n",
"        >>> import numpy as np\n",
"        >>> import pandas as pd\n",
"        >>> import statsmodels.formula.api as smf\n",
"        >>> x = np.linspace(-np.pi, np.pi, 100)\n",
"        >>> y = 3*x + 8 + np.random.normal(0,1, 100)\n",
"        >>> df = pd.DataFrame({'x':x, 'y':y})\n",
"        >>> res = smf.ols(formula= \"y ~ x\", data=df).fit()\n",
"        >>> cls = Linear_Reg_Diagnostic(res)\n",
"        >>> cls(plot_context=\"seaborn-paper\")\n",
"\n",
"        If you do not need all plots, you can also make an individual plot/table\n",
"        independently in the following ways\n",
"\n",
"        >>> cls = Linear_Reg_Diagnostic(res)\n",
"        >>> cls.residual_plot()\n",
"        >>> cls.qq_plot()\n",
"        >>> cls.scale_location_plot()\n",
"        >>> cls.leverage_plot()\n",
"        >>> cls.vif_table()\n",
"        \"\"\"\n",
"\n",
"        if not isinstance(results, statsmodels.regression.linear_model.RegressionResultsWrapper):\n",
"            raise TypeError(\"result must be instance of statsmodels.regression.linear_model.RegressionResultsWrapper object\")\n",
"\n",
"        self.results = maybe_unwrap_results(results)\n",
"\n",
"        self.y_true = self.results.model.endog\n",
"        self.y_predict = self.results.fittedvalues\n",
"        self.xvar = self.results.model.exog\n",
"        self.xvar_names = self.results.model.exog_names\n",
"\n",
"        self.residual = np.array(self.results.resid)\n",
"        influence = self.results.get_influence()\n",
"        self.residual_norm = influence.resid_studentized_internal\n",
"        self.leverage = influence.hat_matrix_diag\n",
"        self.cooks_distance = influence.cooks_distance[0]\n",
"        self.nparams = len(self.results.params)\n",
"\n",
"    def __call__(self, plot_context='seaborn-paper'):\n",
"        # print(plt.style.available)\n",
"        # From Matplotlib 3.6 on, pass plot_context='seaborn-v0_8-paper' instead.\n",
"        with plt.style.context(plot_context):\n",
"            fig, ax = plt.subplots(nrows=2, ncols=2, 
figsize=(10,10))\n",
"            self.residual_plot(ax=ax[0,0])\n",
"            self.qq_plot(ax=ax[0,1])\n",
"            self.scale_location_plot(ax=ax[1,0])\n",
"            self.leverage_plot(ax=ax[1,1])\n",
"            plt.show()\n",
"\n",
"        self.vif_table()\n",
"        return fig, ax\n",
"\n",
"\n",
"    def residual_plot(self, ax=None):\n",
"        \"\"\"\n",
"        Residual vs Fitted Plot\n",
"\n",
"        Graphical tool to identify non-linearity.\n",
"        A (roughly) horizontal red line indicates that the residuals show no\n",
"        systematic pattern, i.e. that a linear fit is adequate.\n",
"        \"\"\"\n",
"        if ax is None:\n",
"            fig, ax = plt.subplots()\n",
"\n",
"        sns.residplot(\n",
"            x=self.y_predict,\n",
"            y=self.residual,\n",
"            lowess=True,\n",
"            scatter_kws={'alpha': 0.5},\n",
"            line_kws={'color': 'red', 'lw': 1, 'alpha': 0.8},\n",
"            ax=ax)\n",
"\n",
"        # annotations: label the three observations with the largest absolute residuals\n",
"        residual_abs = np.abs(self.residual)\n",
"        abs_resid = np.flip(np.argsort(residual_abs), 0)\n",
"        abs_resid_top_3 = abs_resid[:3]\n",
"        for i in abs_resid_top_3:\n",
"            ax.annotate(\n",
"                i,\n",
"                xy=(self.y_predict[i], self.residual[i]),\n",
"                color='C3')\n",
"\n",
"        ax.set_title('Residuals vs Fitted', fontweight=\"bold\")\n",
"        ax.set_xlabel('Fitted values')\n",
"        ax.set_ylabel('Residuals')\n",
"        return ax\n",
"\n",
"    def qq_plot(self, ax=None):\n",
"        \"\"\"\n",
"        Standardized Residual vs Theoretical Quantile plot\n",
"\n",
"        Used to visually check if residuals are normally distributed.\n",
"        Points lying along the diagonal line suggest so.\n",
"        \"\"\"\n",
"        if ax is None:\n",
"            fig, ax = plt.subplots()\n",
"\n",
"        QQ = ProbPlot(self.residual_norm)\n",
"        QQ.qqplot(line='45', alpha=0.5, lw=1, ax=ax)\n",
"\n",
"        # annotations: label the three largest standardized residuals (by absolute value)\n",
"        abs_norm_resid = np.flip(np.argsort(np.abs(self.residual_norm)), 0)\n",
"        abs_norm_resid_top_3 = abs_norm_resid[:3]\n",
"        for r, i in enumerate(abs_norm_resid_top_3):\n",
"            ax.annotate(\n",
"                i,\n",
"                xy=(np.flip(QQ.theoretical_quantiles, 0)[r], self.residual_norm[i]),\n",
"                ha='right', color='C3')\n",
"\n",
"        ax.set_title('Normal Q-Q', fontweight=\"bold\")\n",
"        
ax.set_xlabel('Theoretical Quantiles')\n",
"        ax.set_ylabel('Standardized Residuals')\n",
"        return ax\n",
"\n",
"    def scale_location_plot(self, ax=None):\n",
"        \"\"\"\n",
"        Sqrt(|Standardized Residual|) vs Fitted values plot\n",
"\n",
"        Used to check homoscedasticity of the residuals.\n",
"        A (roughly) horizontal line suggests so.\n",
"        \"\"\"\n",
"        if ax is None:\n",
"            fig, ax = plt.subplots()\n",
"\n",
"        residual_norm_abs_sqrt = np.sqrt(np.abs(self.residual_norm))\n",
"\n",
"        ax.scatter(self.y_predict, residual_norm_abs_sqrt, alpha=0.5)\n",
"        sns.regplot(\n",
"            x=self.y_predict,\n",
"            y=residual_norm_abs_sqrt,\n",
"            scatter=False, ci=False,\n",
"            lowess=True,\n",
"            line_kws={'color': 'red', 'lw': 1, 'alpha': 0.8},\n",
"            ax=ax)\n",
"\n",
"        # annotations: label the three largest values of sqrt(|standardized residual|)\n",
"        abs_sq_norm_resid = np.flip(np.argsort(residual_norm_abs_sqrt), 0)\n",
"        abs_sq_norm_resid_top_3 = abs_sq_norm_resid[:3]\n",
"        for i in abs_sq_norm_resid_top_3:\n",
"            ax.annotate(\n",
"                i,\n",
"                xy=(self.y_predict[i], residual_norm_abs_sqrt[i]),\n",
"                color='C3')\n",
"        ax.set_title('Scale-Location', fontweight=\"bold\")\n",
"        ax.set_xlabel('Fitted values')\n",
"        ax.set_ylabel(r'$\\sqrt{|\\mathrm{Standardized\\ Residuals}|}$')\n",
"        return ax\n",
"\n",
"    def leverage_plot(self, ax=None):\n",
"        \"\"\"\n",
"        Residual vs Leverage plot\n",
"\n",
"        Points falling outside the Cook's distance curves are considered\n",
"        influential, i.e. observations that can sway the fit.\n",
"        Good to have none outside the curves.\n",
"        \"\"\"\n",
"        if ax is None:\n",
"            fig, ax = plt.subplots()\n",
"\n",
"        ax.scatter(\n",
"            self.leverage,\n",
"            self.residual_norm,\n",
"            alpha=0.5)\n",
"\n",
"        sns.regplot(\n",
"            x=self.leverage,\n",
"            y=self.residual_norm,\n",
"            scatter=False,\n",
"            ci=False,\n",
"            lowess=True,\n",
"            line_kws={'color': 'red', 'lw': 1, 'alpha': 0.8},\n",
"            ax=ax)\n",
"\n",
"        # annotations: label the three observations with the largest Cook's distance\n",
"        leverage_top_3 = np.flip(np.argsort(self.cooks_distance), 0)[:3]\n",
"        for i in leverage_top_3:\n",
"            ax.annotate(\n",
"                i, 
\n",
"                xy=(self.leverage[i], self.residual_norm[i]),\n",
"                color='C3')\n",
"\n",
"        xtemp, ytemp = self.__cooks_dist_line(0.5) # 0.5 line\n",
"        ax.plot(xtemp, ytemp, label=\"Cook's distance\", lw=1, ls='--', color='red')\n",
"        xtemp, ytemp = self.__cooks_dist_line(1) # 1 line\n",
"        ax.plot(xtemp, ytemp, lw=1, ls='--', color='red')\n",
"\n",
"        ax.set_xlim(0, max(self.leverage)+0.01)\n",
"        ax.set_title('Residuals vs Leverage', fontweight=\"bold\")\n",
"        ax.set_xlabel('Leverage')\n",
"        ax.set_ylabel('Standardized Residuals')\n",
"        ax.legend(loc='upper right')\n",
"        return ax\n",
"\n",
"    def vif_table(self):\n",
"        \"\"\"\n",
"        VIF table\n",
"\n",
"        VIF, the variance inflation factor, is a measure of multicollinearity.\n",
"        VIF > 5 for a variable indicates that it is highly collinear with the\n",
"        other input variables.\n",
"        \"\"\"\n",
"        vif_df = pd.DataFrame()\n",
"        vif_df[\"Features\"] = self.xvar_names\n",
"        vif_df[\"VIF Factor\"] = [variance_inflation_factor(self.xvar, i) for i in range(self.xvar.shape[1])]\n",
"\n",
"        print(vif_df\n",
"                .sort_values(\"VIF Factor\")\n",
"                .round(2))\n",
"\n",
"\n",
"    def __cooks_dist_line(self, factor):\n",
"        \"\"\"\n",
"        Helper function for plotting Cook's distance curves\n",
"        \"\"\"\n",
"        p = self.nparams\n",
"        formula = lambda x: np.sqrt((factor * p * (1 - x)) / x)\n",
"        x = np.linspace(0.001, max(self.leverage), 50)\n",
"        y = formula(x)\n",
"        return x, y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Using\n",
"\n",
" * the model fitted to the Advertising data above and\n",
" * the base code provided,\n",
"\n",
"we now generate the diagnostic plots one by one." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:06.767559Z", "iopub.status.busy": "2023-01-26T15:15:06.767261Z", "iopub.status.idle": "2023-01-26T15:15:06.775720Z", "shell.execute_reply": "2023-01-26T15:15:06.775170Z" } }, "outputs": [], "source": [ "cls = Linear_Reg_Diagnostic(res)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**A. Residual vs Fitted values**\n", "\n", "Graphical tool to identify non-linearity.\n", "\n", "In the graph, a (roughly) horizontal red line indicates that the residuals show no systematic pattern, i.e. that a linear fit is adequate." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-01-26T15:15:06.779107Z", "iopub.status.busy": "2023-01-26T15:15:06.778272Z", "iopub.status.idle": "2023-01-26T15:15:07.015097Z", "shell.execute_reply": "2023-01-26T15:15:07.014528Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "