{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Decision Trees" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This project requires Python 3.7 or above:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "assert sys.version_info >= (3, 7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It also requires Scikit-Learn ≥ 1.0.1:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from packaging import version\n", "import sklearn\n", "\n", "assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's define the default font sizes to make the figures prettier:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.rc('font', size=14)\n", "plt.rc('axes', labelsize=14, titlesize=14)\n", "plt.rc('legend', fontsize=14)\n", "plt.rc('xtick', labelsize=10)\n", "plt.rc('ytick', labelsize=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And let's create the `images/decision_trees` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "IMAGES_PATH = Path() / \"images\" / \"decision_trees\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "# Training and Visualizing a Decision Tree" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
DecisionTreeClassifier(max_depth=2, random_state=42)
DecisionTreeClassifier(min_samples_leaf=5, random_state=42)
DecisionTreeRegressor(max_depth=2, random_state=42)
DecisionTreeRegressor(max_depth=3, random_state=42)
DecisionTreeClassifier(max_depth=2, random_state=42)
DecisionTreeClassifier(max_depth=2, random_state=40)
GridSearchCV(estimator=DecisionTreeClassifier(),
             param_grid={'criterion': ['gini', 'entropy'],
                         'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8]})
DecisionTreeClassifier(max_leaf_nodes=4)