diff --git a/decision tree classification.ipynb b/decision tree classification.ipynb
index 3ffdabe..ff34977 100644
--- a/decision tree classification.ipynb
+++ b/decision tree classification.ipynb
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {
"scrolled": false
},
@@ -66,7 +66,7 @@
"
3.5 | \n",
" 1.4 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
" \n",
" \n",
" 1 | \n",
@@ -74,7 +74,7 @@
" 3.0 | \n",
" 1.4 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 2 | \n",
@@ -82,7 +82,7 @@
" 3.2 | \n",
" 1.3 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 3 | \n",
@@ -90,7 +90,7 @@
" 3.1 | \n",
" 1.5 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 4 | \n",
@@ -98,7 +98,7 @@
" 3.6 | \n",
" 1.4 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 5 | \n",
@@ -106,7 +106,7 @@
" 3.9 | \n",
" 1.7 | \n",
" 0.4 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 6 | \n",
@@ -114,7 +114,7 @@
" 3.4 | \n",
" 1.4 | \n",
" 0.3 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 7 | \n",
@@ -122,7 +122,7 @@
" 3.4 | \n",
" 1.5 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 8 | \n",
@@ -130,7 +130,7 @@
" 2.9 | \n",
" 1.4 | \n",
" 0.2 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
" 9 | \n",
@@ -138,27 +138,27 @@
" 3.1 | \n",
" 1.5 | \n",
" 0.1 | \n",
- " 0 | \n",
+ " Setosa | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " sepal_length sepal_width petal_length petal_width type\n",
- "0 5.1 3.5 1.4 0.2 0\n",
- "1 4.9 3.0 1.4 0.2 0\n",
- "2 4.7 3.2 1.3 0.2 0\n",
- "3 4.6 3.1 1.5 0.2 0\n",
- "4 5.0 3.6 1.4 0.2 0\n",
- "5 5.4 3.9 1.7 0.4 0\n",
- "6 4.6 3.4 1.4 0.3 0\n",
- "7 5.0 3.4 1.5 0.2 0\n",
- "8 4.4 2.9 1.4 0.2 0\n",
- "9 4.9 3.1 1.5 0.1 0"
+ " sepal_length sepal_width petal_length petal_width type\n",
+ "0 5.1 3.5 1.4 0.2 Setosa\n",
+ "1 4.9 3.0 1.4 0.2 Setosa\n",
+ "2 4.7 3.2 1.3 0.2 Setosa\n",
+ "3 4.6 3.1 1.5 0.2 Setosa\n",
+ "4 5.0 3.6 1.4 0.2 Setosa\n",
+ "5 5.4 3.9 1.7 0.4 Setosa\n",
+ "6 4.6 3.4 1.4 0.3 Setosa\n",
+ "7 5.0 3.4 1.5 0.2 Setosa\n",
+ "8 4.4 2.9 1.4 0.2 Setosa\n",
+ "9 4.9 3.1 1.5 0.1 Setosa"
]
},
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -178,7 +178,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -206,7 +206,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -221,6 +221,7 @@
" self.min_samples_split = min_samples_split\n",
" self.max_depth = max_depth\n",
" \n",
+ " \n",
" def build_tree(self, dataset, curr_depth=0):\n",
" ''' recursive function to build the tree ''' \n",
" \n",
@@ -229,14 +230,17 @@
" \n",
" # split until stopping conditions are met\n",
" if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n",
+ " \n",
" # find the best split\n",
- " best_split = self.get_best_split(dataset, num_samples, num_features)\n",
+ " best_split = self.get_best_split(dataset, num_features)\n",
+ " \n",
" # check if information gain is positive\n",
" if best_split[\"info_gain\"]>0:\n",
" # recur left\n",
" left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n",
" # recur right\n",
" right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n",
+ " \n",
" # return decision node\n",
" return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n",
" left_subtree, right_subtree, best_split[\"info_gain\"])\n",
@@ -246,7 +250,8 @@
" # return leaf node\n",
" return Node(value=leaf_value)\n",
" \n",
- " def get_best_split(self, dataset, num_samples, num_features):\n",
+ " \n",
+ " def get_best_split(self, dataset, num_features):\n",
" ''' function to find the best split '''\n",
" \n",
" # dictionary to store the best split\n",
@@ -257,16 +262,19 @@
" for feature_index in range(num_features):\n",
" feature_values = dataset[:, feature_index]\n",
" possible_thresholds = np.unique(feature_values)\n",
+ " \n",
" # loop over all the feature values present in the data\n",
" for threshold in possible_thresholds:\n",
" # get current split\n",
" dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n",
+ " \n",
" # check if childs are not null\n",
" if len(dataset_left)>0 and len(dataset_right)>0:\n",
" y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n",
" # compute information gain\n",
" curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n",
- " # update the best split if needed\n",
+ "\n",
+ " # update the best split if needed \n",
" if curr_info_gain>max_info_gain:\n",
" best_split[\"feature_index\"] = feature_index\n",
" best_split[\"threshold\"] = threshold\n",
@@ -278,6 +286,7 @@
" # return best split\n",
" return best_split\n",
" \n",
+ " \n",
" def split(self, dataset, feature_index, threshold):\n",
" ''' function to split the data '''\n",
" \n",
@@ -285,17 +294,20 @@
" dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n",
" return dataset_left, dataset_right\n",
" \n",
+ " \n",
" def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n",
" ''' function to compute information gain '''\n",
" \n",
" weight_l = len(l_child) / len(parent)\n",
" weight_r = len(r_child) / len(parent)\n",
+ " \n",
" if mode==\"gini\":\n",
" gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n",
" else:\n",
" gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n",
" return gain\n",
" \n",
+ " \n",
" def entropy(self, y):\n",
" ''' function to compute entropy '''\n",
" \n",
@@ -306,6 +318,7 @@
" entropy += -p_cls * np.log2(p_cls)\n",
" return entropy\n",
" \n",
+ " \n",
" def gini_index(self, y):\n",
" ''' function to compute gini index '''\n",
" \n",
@@ -316,12 +329,14 @@
" gini += p_cls**2\n",
" return 1 - gini\n",
" \n",
+ " \n",
" def calculate_leaf_value(self, Y):\n",
" ''' function to compute leaf node '''\n",
" \n",
" Y = list(Y)\n",
" return max(Y, key=Y.count)\n",
" \n",
+ " \n",
" def print_tree(self, tree=None, indent=\" \"):\n",
" ''' function to print the tree '''\n",
" \n",
@@ -332,17 +347,19 @@
" print(tree.value)\n",
"\n",
" else:\n",
- " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.info_gain)\n",
+ " print(f'X_{str(tree.feature_index)} <= {tree.threshold} ? {tree.info_gain}')\n",
" print(\"%sleft:\" % (indent), end=\"\")\n",
" self.print_tree(tree.left, indent + indent)\n",
" print(\"%sright:\" % (indent), end=\"\")\n",
" self.print_tree(tree.right, indent + indent)\n",
+ " \n",
" \n",
" def fit(self, X, Y):\n",
" ''' function to train the tree '''\n",
" \n",
" dataset = np.concatenate((X, Y), axis=1)\n",
" self.root = self.build_tree(dataset)\n",
+ " \n",
" \n",
" def predict(self, X):\n",
" ''' function to predict new dataset '''\n",
@@ -350,12 +367,13 @@
" preditions = [self.make_prediction(x, self.root) for x in X]\n",
" return preditions\n",
" \n",
+ " \n",
" def make_prediction(self, x, tree):\n",
" ''' function to predict a single data point '''\n",
" \n",
- " if tree.value!=None: return tree.value\n",
+ " if tree.value is not None: return tree.value\n",
" feature_val = x[tree.feature_index]\n",
- " if feature_val<=tree.threshold:\n",
+ " if feature_val <= tree.threshold:\n",
" return self.make_prediction(x, tree.left)\n",
" else:\n",
" return self.make_prediction(x, tree.right)"
@@ -370,7 +388,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -389,24 +407,24 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "X_2 <= 1.9 ? 0.33741385372714494\n",
- " left:0.0\n",
- " right:X_3 <= 1.5 ? 0.427106638180289\n",
- " left:X_2 <= 4.9 ? 0.05124653739612173\n",
- " left:1.0\n",
- " right:2.0\n",
- " right:X_2 <= 5.0 ? 0.019631171921475288\n",
- " left:X_1 <= 2.8 ? 0.20833333333333334\n",
- " left:2.0\n",
- " right:1.0\n",
- " right:2.0\n"
+ "X_2 <= 1.9 ? 0.33741385372714494\n",
+ " left:Setosa\n",
+ " right:X_3 <= 1.5 ? 0.427106638180289\n",
+ " left:X_2 <= 4.9 ? 0.05124653739612173\n",
+ " left:Versicolor\n",
+ " right:Virginica\n",
+ " right:X_2 <= 5.0 ? 0.019631171921475288\n",
+ " left:X_1 <= 2.8 ? 0.20833333333333334\n",
+ " left:Virginica\n",
+ " right:Versicolor\n",
+ " right:Virginica\n"
]
}
],
@@ -425,7 +443,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -434,7 +452,7 @@
"0.9333333333333333"
]
},
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -462,7 +480,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.5"
+ "version": "3.12.3"
}
},
"nbformat": 4,
diff --git a/iris.csv b/iris.csv
new file mode 100644
index 0000000..1b9d029
--- /dev/null
+++ b/iris.csv
@@ -0,0 +1,151 @@
+"sepal.length","sepal.width","petal.length","petal.width","variety"
+5.1,3.5,1.4,.2,"Setosa"
+4.9,3,1.4,.2,"Setosa"
+4.7,3.2,1.3,.2,"Setosa"
+4.6,3.1,1.5,.2,"Setosa"
+5,3.6,1.4,.2,"Setosa"
+5.4,3.9,1.7,.4,"Setosa"
+4.6,3.4,1.4,.3,"Setosa"
+5,3.4,1.5,.2,"Setosa"
+4.4,2.9,1.4,.2,"Setosa"
+4.9,3.1,1.5,.1,"Setosa"
+5.4,3.7,1.5,.2,"Setosa"
+4.8,3.4,1.6,.2,"Setosa"
+4.8,3,1.4,.1,"Setosa"
+4.3,3,1.1,.1,"Setosa"
+5.8,4,1.2,.2,"Setosa"
+5.7,4.4,1.5,.4,"Setosa"
+5.4,3.9,1.3,.4,"Setosa"
+5.1,3.5,1.4,.3,"Setosa"
+5.7,3.8,1.7,.3,"Setosa"
+5.1,3.8,1.5,.3,"Setosa"
+5.4,3.4,1.7,.2,"Setosa"
+5.1,3.7,1.5,.4,"Setosa"
+4.6,3.6,1,.2,"Setosa"
+5.1,3.3,1.7,.5,"Setosa"
+4.8,3.4,1.9,.2,"Setosa"
+5,3,1.6,.2,"Setosa"
+5,3.4,1.6,.4,"Setosa"
+5.2,3.5,1.5,.2,"Setosa"
+5.2,3.4,1.4,.2,"Setosa"
+4.7,3.2,1.6,.2,"Setosa"
+4.8,3.1,1.6,.2,"Setosa"
+5.4,3.4,1.5,.4,"Setosa"
+5.2,4.1,1.5,.1,"Setosa"
+5.5,4.2,1.4,.2,"Setosa"
+4.9,3.1,1.5,.2,"Setosa"
+5,3.2,1.2,.2,"Setosa"
+5.5,3.5,1.3,.2,"Setosa"
+4.9,3.6,1.4,.1,"Setosa"
+4.4,3,1.3,.2,"Setosa"
+5.1,3.4,1.5,.2,"Setosa"
+5,3.5,1.3,.3,"Setosa"
+4.5,2.3,1.3,.3,"Setosa"
+4.4,3.2,1.3,.2,"Setosa"
+5,3.5,1.6,.6,"Setosa"
+5.1,3.8,1.9,.4,"Setosa"
+4.8,3,1.4,.3,"Setosa"
+5.1,3.8,1.6,.2,"Setosa"
+4.6,3.2,1.4,.2,"Setosa"
+5.3,3.7,1.5,.2,"Setosa"
+5,3.3,1.4,.2,"Setosa"
+7,3.2,4.7,1.4,"Versicolor"
+6.4,3.2,4.5,1.5,"Versicolor"
+6.9,3.1,4.9,1.5,"Versicolor"
+5.5,2.3,4,1.3,"Versicolor"
+6.5,2.8,4.6,1.5,"Versicolor"
+5.7,2.8,4.5,1.3,"Versicolor"
+6.3,3.3,4.7,1.6,"Versicolor"
+4.9,2.4,3.3,1,"Versicolor"
+6.6,2.9,4.6,1.3,"Versicolor"
+5.2,2.7,3.9,1.4,"Versicolor"
+5,2,3.5,1,"Versicolor"
+5.9,3,4.2,1.5,"Versicolor"
+6,2.2,4,1,"Versicolor"
+6.1,2.9,4.7,1.4,"Versicolor"
+5.6,2.9,3.6,1.3,"Versicolor"
+6.7,3.1,4.4,1.4,"Versicolor"
+5.6,3,4.5,1.5,"Versicolor"
+5.8,2.7,4.1,1,"Versicolor"
+6.2,2.2,4.5,1.5,"Versicolor"
+5.6,2.5,3.9,1.1,"Versicolor"
+5.9,3.2,4.8,1.8,"Versicolor"
+6.1,2.8,4,1.3,"Versicolor"
+6.3,2.5,4.9,1.5,"Versicolor"
+6.1,2.8,4.7,1.2,"Versicolor"
+6.4,2.9,4.3,1.3,"Versicolor"
+6.6,3,4.4,1.4,"Versicolor"
+6.8,2.8,4.8,1.4,"Versicolor"
+6.7,3,5,1.7,"Versicolor"
+6,2.9,4.5,1.5,"Versicolor"
+5.7,2.6,3.5,1,"Versicolor"
+5.5,2.4,3.8,1.1,"Versicolor"
+5.5,2.4,3.7,1,"Versicolor"
+5.8,2.7,3.9,1.2,"Versicolor"
+6,2.7,5.1,1.6,"Versicolor"
+5.4,3,4.5,1.5,"Versicolor"
+6,3.4,4.5,1.6,"Versicolor"
+6.7,3.1,4.7,1.5,"Versicolor"
+6.3,2.3,4.4,1.3,"Versicolor"
+5.6,3,4.1,1.3,"Versicolor"
+5.5,2.5,4,1.3,"Versicolor"
+5.5,2.6,4.4,1.2,"Versicolor"
+6.1,3,4.6,1.4,"Versicolor"
+5.8,2.6,4,1.2,"Versicolor"
+5,2.3,3.3,1,"Versicolor"
+5.6,2.7,4.2,1.3,"Versicolor"
+5.7,3,4.2,1.2,"Versicolor"
+5.7,2.9,4.2,1.3,"Versicolor"
+6.2,2.9,4.3,1.3,"Versicolor"
+5.1,2.5,3,1.1,"Versicolor"
+5.7,2.8,4.1,1.3,"Versicolor"
+6.3,3.3,6,2.5,"Virginica"
+5.8,2.7,5.1,1.9,"Virginica"
+7.1,3,5.9,2.1,"Virginica"
+6.3,2.9,5.6,1.8,"Virginica"
+6.5,3,5.8,2.2,"Virginica"
+7.6,3,6.6,2.1,"Virginica"
+4.9,2.5,4.5,1.7,"Virginica"
+7.3,2.9,6.3,1.8,"Virginica"
+6.7,2.5,5.8,1.8,"Virginica"
+7.2,3.6,6.1,2.5,"Virginica"
+6.5,3.2,5.1,2,"Virginica"
+6.4,2.7,5.3,1.9,"Virginica"
+6.8,3,5.5,2.1,"Virginica"
+5.7,2.5,5,2,"Virginica"
+5.8,2.8,5.1,2.4,"Virginica"
+6.4,3.2,5.3,2.3,"Virginica"
+6.5,3,5.5,1.8,"Virginica"
+7.7,3.8,6.7,2.2,"Virginica"
+7.7,2.6,6.9,2.3,"Virginica"
+6,2.2,5,1.5,"Virginica"
+6.9,3.2,5.7,2.3,"Virginica"
+5.6,2.8,4.9,2,"Virginica"
+7.7,2.8,6.7,2,"Virginica"
+6.3,2.7,4.9,1.8,"Virginica"
+6.7,3.3,5.7,2.1,"Virginica"
+7.2,3.2,6,1.8,"Virginica"
+6.2,2.8,4.8,1.8,"Virginica"
+6.1,3,4.9,1.8,"Virginica"
+6.4,2.8,5.6,2.1,"Virginica"
+7.2,3,5.8,1.6,"Virginica"
+7.4,2.8,6.1,1.9,"Virginica"
+7.9,3.8,6.4,2,"Virginica"
+6.4,2.8,5.6,2.2,"Virginica"
+6.3,2.8,5.1,1.5,"Virginica"
+6.1,2.6,5.6,1.4,"Virginica"
+7.7,3,6.1,2.3,"Virginica"
+6.3,3.4,5.6,2.4,"Virginica"
+6.4,3.1,5.5,1.8,"Virginica"
+6,3,4.8,1.8,"Virginica"
+6.9,3.1,5.4,2.1,"Virginica"
+6.7,3.1,5.6,2.4,"Virginica"
+6.9,3.1,5.1,2.3,"Virginica"
+5.8,2.7,5.1,1.9,"Virginica"
+6.8,3.2,5.9,2.3,"Virginica"
+6.7,3.3,5.7,2.5,"Virginica"
+6.7,3,5.2,2.3,"Virginica"
+6.3,2.5,5,1.9,"Virginica"
+6.5,3,5.2,2,"Virginica"
+6.2,3.4,5.4,2.3,"Virginica"
+5.9,3,5.1,1.8,"Virginica"
\ No newline at end of file