决策树常见的三种算法：ID3、C4.5、CART。
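ID3（sklearn 实现 / 手搓）。注意：sklearn 的 DecisionTreeClassifier 实现的是优化版 CART（只做二叉划分），设置 criterion='entropy' 只是用信息增益（熵）作为划分标准来近似 ID3，并不是真正的多叉 ID3 树。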
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# Load the iris dataset; the keys used below are 'data', 'target',
# 'feature_names' and 'target_names'.
datas = load_iris()
print(datas.keys())
X = datas['data']
y = datas['target']

# Build the decision tree model.
# - criterion='entropy': choose splits by information gain (the ID3/C4.5
#   criterion), although sklearn always grows binary, CART-style trees.
# - min_samples_leaf=5: every leaf must keep at least 5 samples, which stops
#   the tree before it fits every training point.
# - random_state=0: sklearn randomly permutes the features at every split, so
#   when two splits are equally good (e.g. "petal width <= 0.8" and
#   "petal length <= 2.45" both isolate setosa at the root) the chosen one
#   changes between runs unless the seed is fixed.
dc_tree = tree.DecisionTreeClassifier(
    criterion='entropy',
    min_samples_leaf=5,
    random_state=0,
)
dc_tree.fit(X, y)

# Accuracy on the same data the tree was trained on: this measures how well
# the tree fits the training set, not how well it generalizes.
y_predict = dc_tree.predict(X)
accuracy = accuracy_score(y, y_predict)
print(accuracy)

# %matplotlib inline   (Jupyter magic; enable when running in a notebook)
# fig = plt.figure(figsize=(10, 10))   # uncomment for a larger, more legible plot
tree.plot_tree(
    dc_tree,
    filled=True,
    feature_names=datas['feature_names'],
    class_names=datas['target_names'],
)
输出（min_samples_leaf=5）：
0.9733333333333334
[Text(0.4444444444444444, 0.9, 'petal width (cm) <= 0.8\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa'),
Text(0.3333333333333333, 0.7, 'entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa'),
Text(0.38888888888888884, 0.8, 'True '),
Text(0.5555555555555556, 0.7, 'petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor'),
Text(0.5, 0.8, ' False'),
Text(0.3333333333333333, 0.5, 'petal length (cm) <= 4.95\nentropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor'),
Text(0.2222222222222222, 0.3, 'sepal length (cm) <= 5.15\nentropy = 0.146\nsamples = 48\nvalue = [0, 47, 1]\nclass = versicolor'),
Text(0.1111111111111111, 0.1, 'entropy = 0.722\nsamples = 5\nvalue = [0, 4, 1]\nclass = versicolor'),
Text(0.3333333333333333, 0.1, 'entropy = 0.0\nsamples = 43\nvalue = [0, 43, 0]\nclass = versicolor'),
Text(0.4444444444444444, 0.3, 'entropy = 0.918\nsamples = 6\nvalue = [0, 2, 4]\nclass = virginica'),
Text(0.7777777777777778, 0.5, 'petal length (cm) <= 4.95\nentropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica'),
Text(0.6666666666666666, 0.3, 'entropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]\nclass = virginica'),
Text(0.8888888888888888, 0.3, 'entropy = 0.0\nsamples = 40\nvalue = [0, 0, 40]\nclass = virginica')]

将 min_samples_leaf 修改为 1（sklearn 的默认值）：叶子节点最少只需 1 个样本，树会一直生长到叶子纯净为止，因此训练集准确率升到 1.0，但也更容易过拟合。
# Same experiment as the previous cell, but with min_samples_leaf=1 (the
# sklearn default): a leaf may hold a single sample, so the tree can keep
# splitting down to pure leaves.
# Reuses `datas`, `X` and `y` loaded in the previous cell.
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.metrics import accuracy_score

# Build the decision tree model. random_state is fixed because sklearn
# permutes features randomly at every split, so ties between equally good
# splits would otherwise be broken differently on every run.
dc_tree = tree.DecisionTreeClassifier(
    criterion='entropy',
    min_samples_leaf=1,
    random_state=0,
)
dc_tree.fit(X, y)

# Training-set accuracy. A value of 1.0 here means the tree memorizes the
# training data; it says nothing about accuracy on unseen samples.
y_predict = dc_tree.predict(X)
accuracy = accuracy_score(y, y_predict)
print(accuracy)

# %matplotlib inline   (Jupyter magic; enable when running in a notebook)
# fig = plt.figure(figsize=(10, 10))   # uncomment for a larger, more legible plot
tree.plot_tree(
    dc_tree,
    filled=True,
    feature_names=datas['feature_names'],
    class_names=datas['target_names'],
)
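输出（min_samples_leaf=1）。注意根节点的划分从上一次的 petal width (cm) <= 0.8 变成了 petal length (cm) <= 2.45：两者都能把 setosa 完全分开（信息增益相同），未设置 random_state 时 sklearn 会随机打破这种平局，所以每次运行得到的树可能不同。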
1.0
[Text(0.5, 0.9166666666666666, 'petal length (cm) <= 2.45\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa'),
Text(0.4230769230769231, 0.75, 'entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa'),
Text(0.46153846153846156, 0.8333333333333333, 'True '),
Text(0.5769230769230769, 0.75, 'petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor'),
Text(0.5384615384615384, 0.8333333333333333, ' False'),
Text(0.3076923076923077, 0.5833333333333334, 'petal length (cm) <= 4.95\nentropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor'),
Text(0.15384615384615385, 0.4166666666666667, 'petal width (cm) <= 1.65\nentropy = 0.146\nsamples = 48\nvalue = [0, 47, 1]\nclass = versicolor'),
Text(0.07692307692307693, 0.25, 'entropy = 0.0\nsamples = 47\nvalue = [0, 47, 0]\nclass = versicolor'),
Text(0.23076923076923078, 0.25, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica'),
Text(0.46153846153846156, 0.4166666666666667, 'petal width (cm) <= 1.55\nentropy = 0.918\nsamples = 6\nvalue = [0, 2, 4]\nclass = virginica'),
Text(0.38461538461538464, 0.25, 'entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]\nclass = virginica'),
Text(0.5384615384615384, 0.25, 'sepal length (cm) <= 6.95\nentropy = 0.918\nsamples = 3\nvalue = [0, 2, 1]\nclass = versicolor'),
Text(0.46153846153846156, 0.08333333333333333, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor'),
Text(0.6153846153846154, 0.08333333333333333, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica'),
Text(0.8461538461538461, 0.5833333333333334, 'petal length (cm) <= 4.85\nentropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica'),
Text(0.7692307692307693, 0.4166666666666667, 'sepal width (cm) <= 3.1\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]\nclass = virginica'),
Text(0.6923076923076923, 0.25, 'entropy = 0.0\nsamples = 2\nvalue = [0, 0, 2]\nclass = virginica'),
Text(0.8461538461538461, 0.25, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor'),
Text(0.9230769230769231, 0.4166666666666667, 'entropy = 0.0\nsamples = 43\nvalue = [0, 0, 43]\nclass = virginica')]

