{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Interactive Example of Decision Trees\n", "\n", "### Authors: \n", "- Christian Michelsen (Niels Bohr Institute)\n", "\n", "### Date: \n", "- 03-01-2019 (latest update)\n", "\n", "***\n", "\n", "This Jupyter notebook interactively illustrates the inner workings of the machine learning (ML) algorithm called the decision tree. \n", "This notebook is based on the same dataset as the `fisher_discrimant.ipynb` notebook: the __[Iris dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set)__. The focus of this small example is neither the actual code nor any specific result, but, hopefully, a better understanding of the foundation of all tree-based ML algorithms: the decision tree. This is also why we don't describe the code in great detail and simply load the dataset directly from scikit-learn with the __[load_iris](https://scikit-learn.org/stable/datasets/index.html#iris-plants-dataset)__ function, but the first part of the code should hopefully look familiar by now. 
\n", "\n", "\n", "\n", "***\n", "\n", "First, we import the modules we want to use:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier, export_graphviz\n", "from sklearn import tree\n", "from sklearn.datasets import load_iris, load_wine\n", "from sklearn.metrics import accuracy_score\n", "from IPython.display import SVG\n", "from graphviz import Source\n", "from IPython.display import display\n", "from ipywidgets import interactive\n", "\n", "\n", "verbose = True\n", "N_verbose = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we load the dataset and extract the feature matrix (the independent variables, with shape `(# samples, # features) = (150, 4)`) together with the target vector:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)', 'class']\n", "[5.1 3.5 1.4 0.2] 0\n", "[4.9 3. 1.4 0.2] 0\n", "[4.7 3.2 1.3 0.2] 0\n", "[4.6 3.1 1.5 0.2] 0\n", "[5. 3.6 1.4 0.2] 0\n", "[5.4 3.9 1.7 0.4] 0\n", "[4.6 3.4 1.4 0.3] 0\n", "[5. 3.4 1.5 0.2] 0\n", "[4.4 2.9 1.4 0.2] 0\n", "[4.9 3.1 1.5 0.1] 0\n" ] } ], "source": [ "# Load dataset - either Fisher's Iris data or the alternative Wine data\n", "data = load_iris()\n", "# data = load_wine()\n", "\n", "# Feature matrix\n", "X = data.data\n", "# Target vector\n", "y = data.target\n", "# Feature names\n", "feature_names = data.feature_names\n", "\n", "# Print the column names and the first N_verbose samples\n", "if verbose:\n", "    print(feature_names + ['class'])\n", "    for i, (xi, yi) in enumerate(zip(X, y)):\n", "        if i < N_verbose:\n", "            print(xi, yi)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part A: Max Depth\n", "\n", "We start with the simple part of the exercise: understanding the (hyper)parameter `depth` (passed to the classifier as `max_depth`) and what it changes. 
Try to play around with the slider below and see how the decision tree graph changes along with the accuracy of the prediction (computed on the same data the tree was fitted to):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5f262764bd3a4654a8114a4fcb9f5382", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=1, description='depth', max=5, min=1), Output()), _dom_classes=('widget-…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def fit_and_grapth_estimator(estimator):\n", "    # Fit on the full dataset and report the accuracy on that same data\n", "    estimator.fit(X, y)\n", "    accuracy = accuracy_score(y, estimator.predict(X))\n", "    print(f'Accuracy: {accuracy:.4f}')\n", "    # Export the fitted tree to Graphviz and render it inline as SVG\n", "    graph = Source(tree.export_graphviz(estimator, out_file=None, feature_names=feature_names, \n", "                                        class_names=['0', '1', '2'], filled = True))\n", "    display(SVG(graph.pipe(format='svg')))\n", "    return estimator\n", "\n", "\n", "def plot_tree_simple(depth=1):\n", "    estimator = DecisionTreeClassifier(random_state = 0, max_depth = depth)\n", "    estimator = fit_and_grapth_estimator(estimator)\n", "    return estimator\n", "\n", "inter_simple = interactive(plot_tree_simple, depth=(1, 5, 1))\n", "\n", "display(inter_simple)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part B: More hyperparameters\n", "\n", "Now we add some extra hyperparameters to change: the criterion `crit`, which sets how the quality of a split is measured (either the __[Gini impurity](https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity)__ or the __[Entropy](https://en.wikipedia.org/wiki/Decision_tree_learning#Information_gain)__), the strategy behind the splitting algorithm `split`, and the minimum number of samples required to split a node or to form a leaf, `min_split`/`min_leaf`. 
See the following link for more information about the __[DecisionTreeClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html)__.\n", "\n", "Try to play around with the different parameters and see how they affect the tree graph and accuracy:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5b812f5ae7b54f629ec6860221cd5619", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=1, description='depth', max=5, min=1), Dropdown(description='crit', opti…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def plot_tree_advanced(depth=1, crit='gini', split='best', min_split=2, min_leaf=1):\n", "    estimator = DecisionTreeClassifier(random_state = 0, \n", "                                       criterion = crit, \n", "                                       splitter = split, \n", "                                       max_depth = depth, \n", "                                       min_samples_split=min_split, \n", "                                       min_samples_leaf=min_leaf,\n", "                                       )\n", "    return fit_and_grapth_estimator(estimator)\n", "\n", "inter_advanced = interactive(plot_tree_advanced, \n", "                             depth=(1, 5, 1),\n", "                             min_split=(2, 10), \n", "                             min_leaf=(1, 10),\n", "                             crit = [\"gini\", \"entropy\"], \n", "                             split = [\"best\", \"random\"], \n", "                             )\n", "\n", "display(inter_advanced)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }