Decision Tree
Last updated
# plot_tree is imported here, with the other sklearn.tree name, so that it
# is in scope for the visualization code that follows this section.
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load the Iris dataset bundled with scikit-learn.
iris = load_iris()

# Hold out 30% of the samples for evaluation; the fixed seed makes the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42)

# Gini-impurity tree capped at depth 3 to keep it small and readable.
# random_state pins the estimator's internal feature permutation so that ties
# between equally good splits are broken the same way on every run; without
# it the reported accuracy can change between runs even with a fixed split.
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# Predict on the held-out set and report accuracy (label "準確率" = "accuracy").
y_pred = clf.predict(X_test)
print("準確率:", accuracy_score(y_test, y_pred))
import matplotlib.pyplot as plt

# Draw the fitted classifier on an explicit Axes, labelling splits with the
# dataset's feature names and leaves with its class names; filled/rounded
# control the node styling.
_, ax = plt.subplots(figsize=(10, 6))
plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    ax=ax,
)
ax.set_title("Decision Tree")
plt.show()