{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# k-Nearest Neighbors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "k-nearest neighbors (*kNN*) is an intuitively simple algorithm in which the label (in classification) or continuous value (in regression) of an unknown test data point is determined from the *k* closest points to it in a training dataset. The distance between data points is often calculated using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance), although there are other possible choices.\n", "\n", "```{note}\n", "*kNN* is sometimes called a \"lazy\" model because it leaves all the computational work until testing time. Training simply involves storing the training dataset, whereas testing an unknown data point involves searching the training dataset for the *k* nearest neighbors.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The animation below shows an example of a *kNN* classifier, where an unknown data point is classified using the most common class amongst the *k* nearest neighbors (measured by Euclidean distance). In the example, there are two features (*X1* and *X2*) and the response (*Y*). Choose different values of *k* from the drop-down menu and observe how the classification of the unknown data point changes.\n", "\n", "```{note}\n", "In the regression setting, it is common to assign the average value of the *k* nearest neighbors to the test data point.\n", "```" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "hide-input" ] }, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Imports\n", "import numpy as np\n", "import pandas as pd\n", "import plotly.io as pio\n", "import plotly.express as px\n", "import plotly.offline as py\n", "pio.renderers.default = \"notebook\"\n", "\n", "# Create the data\n", "np.random.seed(11)\n", "df = pd.DataFrame({'X1': np.random.randint(1, 10, 9),\n", " 'X2': np.random.randint(1, 10, 9),\n", " 'Y': np.random.choice(['Class 2', 'Class 1'], size=9)})\n", "df.loc[len(df)] = [6, 3, 'Unknown'] # query point\n", "df['Distance'] = ((df[['X1', 'X2']] - df.iloc[-1, :2]) ** 2).sum(axis=1) # distances from query point\n", "df = df.sort_values(by='Distance')\n", "df['Predicted Class'] = ['Unknown', 'Class 1', 'Class 1', 'Class 1', 'Class 1', 'Class 1', 'Class 1', 'Class 2', 'Class 2', 'Class 2']\n", "\n", "# Plot with plotly\n", "color_dict = {\"Class 1\": \"#636EFA\", \"Class 2\": \"#EF553B\", \"Unknown\": \"#7F7F7F\"}\n", "fig = px.scatter(df, x=\"X1\", y=\"X2\", color='Y', color_discrete_map=color_dict,\n", " range_x=[0, 10], range_y=[0, 10],\n", " width=650, height=520)\n", "fig.update_traces(marker=dict(size=20,\n", " line=dict(width=1)))\n", "\n", "# Add lines\n", "shape_dict = {} # create a dictionary\n", "for k in range(0, len(df)):\n", " shape_dict[k] = [dict(type=\"line\", xref=\"x\", yref=\"y\",x0=x, y0=y, x1=6, y1=3, layer='below',\n", " line=dict(color=\"Black\", width=2)) for x, y in df.iloc[1:k+1, :2].to_numpy()]\n", " if k != 0:\n", " shape_dict[k].append(dict(type=\"circle\", xref=\"x\", yref=\"y\",x0=5.75, y0=2.75, x1=6.25, y1=3.25,\n", " fillcolor=color_dict[df.iloc[k, 4]]))\n", "\n", "# Add dropdown\n", "fig.update_layout(\n", " updatemenus=[dict(buttons=[dict(args=[{\"shapes\": shape_dict[k]}],\n", " label=str(k),\n", " method=\"relayout\") for k in range(0, len(df))],\n", " direction=\"down\", showactive=True,\n", " x=0.115, xanchor=\"left\", y=1.14, yanchor=\"top\")])\n", "\n", "# Add dropdown label\n", 
"fig.update_layout(annotations=[dict(text=\"k = \",\n", " x=0, xref=\"paper\", y=1.13, yref=\"paper\",\n", " align=\"left\", showarrow=False)],\n", " font=dict(size=20))" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }