Kann ich die zugrunde liegenden Entscheidungsregeln (oder 'Entscheidungspfade') aus einem trainierten Baum in einem Entscheidungsbaum als Textliste extrahieren?
Etwas wie:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Danke für Ihre Hilfe.
Kann ich die zugrunde liegenden Entscheidungsregeln (oder 'Entscheidungspfade') aus einem trainierten Baum in einem Entscheidungsbaum als Textliste extrahieren?
Etwas wie:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Danke für Ihre Hilfe.
Antworten:
Ich glaube, dass diese Antwort korrekter ist als die anderen Antworten hier:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
Dies druckt eine gültige Python-Funktion aus. Hier ist eine Beispielausgabe für einen Baum, der versucht, seine Eingabe zurückzugeben, eine Zahl zwischen 0 und 10.
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
Hier sind einige Stolpersteine, die ich in anderen Antworten sehe:
tree_.threshold == -2
Entscheidung, ob ein Knoten ein Blatt ist, ist keine gute Idee. Was ist, wenn es sich um einen echten Entscheidungsknoten mit einem Schwellenwert von -2 handelt? Stattdessen sollten Sie sich tree.feature
oder ansehen tree.children_*
.features = [feature_names[i] for i in tree_.feature]
stürzt mit meiner Version von sklearn ab, da einige Werte vontree.tree_.feature
-2 sind (speziell für Blattknoten).print "{}return {}".format(indent, tree_.value[node])
sollte geändert werden, print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))
damit die Funktion den Klassenindex zurückgibt .
RandomForestClassifier.estimators_
, aber ich konnte nicht herausfinden, wie die Ergebnisse der Schätzer kombiniert werden können.
print "bla"
=>print("bla")
Ich habe meine eigene Funktion erstellt, um die Regeln aus den von sklearn erstellten Entscheidungsbäumen zu extrahieren:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
Diese Funktion beginnt zuerst mit den Knoten (in den untergeordneten Arrays mit -1 gekennzeichnet) und findet dann rekursiv die übergeordneten Knoten. Ich nenne dies die "Linie" eines Knotens. Unterwegs greife ich zu den Werten, die ich erstellen muss, wenn / dann / sonst SAS-Logik:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print node
Die folgenden Tupelgruppen enthalten alles, was ich zum Erstellen von SAS if / then / else-Anweisungen benötige. Ich mag es nicht, do
Blöcke in SAS zu verwenden, weshalb ich eine Logik erstelle, die den gesamten Pfad eines Knotens beschreibt. Die einzelne Ganzzahl nach den Tupeln ist die ID des Endknotens in einem Pfad. Alle vorhergehenden Tupel bilden zusammen diesen Knoten.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
(0.5, 2.5]
. Die Bäume werden mit rekursiver Partitionierung erstellt. Nichts hindert eine Variable daran, mehrmals ausgewählt zu werden.
Ich habe den von Zelazny7 übermittelten Code geändert , um einen Pseudocode zu drucken:
def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
Wenn Sie get_code(dt, df.columns)
dasselbe Beispiel aufrufen, erhalten Sie:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
(threshold[node] != -2)
zu ( left[node] != -1)
(ähnlich der folgenden Methode zum Abrufen von IDs von
Scikit Learn führte eine köstliche neue Methode ein, die export_text
in Version 0.21 (Mai 2019) aufgerufen wurde , um die Regeln aus einem Baum zu extrahieren. Dokumentation hier . Es ist nicht mehr erforderlich, eine benutzerdefinierte Funktion zu erstellen.
Sobald Sie Ihr Modell angepasst haben, benötigen Sie nur zwei Codezeilen. Importieren Sie zunächst export_text
:
from sklearn.tree.export import export_text
Zweitens erstellen Sie ein Objekt, das Ihre Regeln enthält. Verwenden Sie das feature_names
Argument und übergeben Sie eine Liste Ihrer Feature-Namen, damit die Regeln besser lesbar sind . Wenn beispielsweise Ihr Modell aufgerufen wird model
und Ihre Features in einem aufgerufenen Datenrahmen benannt sind X_train
, können Sie ein Objekt mit dem Namen erstellen tree_rules
:
tree_rules = export_text(model, feature_names=list(X_train))
Dann einfach ausdrucken oder speichern tree_rules
. Ihre Ausgabe sieht folgendermaßen aus:
|--- Age <= 0.63
| |--- EstimatedSalary <= 0.61
| | |--- Age <= -0.16
| | | |--- class: 0
| | |--- Age > -0.16
| | | |--- EstimatedSalary <= -0.06
| | | | |--- class: 0
| | | |--- EstimatedSalary > -0.06
| | | | |--- EstimatedSalary <= 0.40
| | | | | |--- EstimatedSalary <= 0.03
| | | | | | |--- class: 1
In der Version 0.18.0 gibt es eine neue DecisionTreeClassifier
Methode . Die Entwickler bieten eine umfangreiche (gut dokumentierte) exemplarische Vorgehensweisedecision_path
.
Der erste Codeabschnitt in der exemplarischen Vorgehensweise, der die Baumstruktur druckt, scheint in Ordnung zu sein. Ich habe jedoch den Code im zweiten Abschnitt geändert, um ein Beispiel abzufragen. Meine Änderungen bezeichnet mit# <--
Bearbeiten Die # <--
im folgenden Code gekennzeichneten Änderungen wurden seitdem im exemplarischen Link aktualisiert, nachdem auf die Fehler in den Pull-Anforderungen Nr. 8653 und Nr. 10951 hingewiesen wurde . Es ist jetzt viel einfacher, mitzumachen.
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]
print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
if leave_id[sample_id] == node_id: # <-- changed != to ==
#continue # <-- comment out
print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--
else: # < -- added else to iterate through decision nodes
if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"
print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
% (node_id,
sample_id,
feature[node_id],
X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
threshold_sign,
threshold[node_id]))
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
Ändern Sie die sample_id
, um die Entscheidungspfade für andere Beispiele anzuzeigen. Ich habe die Entwickler nicht nach diesen Änderungen gefragt, sondern schien beim Durcharbeiten des Beispiels nur intuitiver zu sein.
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()
Sie können einen Digraphenbaum sehen. Dann sind clf.tree_.feature
und clf.tree_.value
Array von Knoten, die das Merkmal aufteilen, bzw. Array von Knotenwerten. Weitere Details finden Sie in dieser Github-Quelle .
Nur weil alle so hilfreich waren, füge ich einfach eine Modifikation zu Zelazny7 und Danieles schönen Lösungen hinzu. Dieser ist für Python 2.7 mit Registerkarten, um die Lesbarkeit zu verbessern:
def get_code(tree, feature_names, tabdepth=0):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node, tabdepth=0):
if (threshold[node] != -2):
print '\t' * tabdepth,
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node], tabdepth+1)
print '\t' * tabdepth,
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node], tabdepth+1)
print '\t' * tabdepth,
print "}"
else:
print '\t' * tabdepth,
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
Die folgenden Codes sind mein Ansatz unter Anaconda Python 2.7 sowie der Paketname "pydot-ng" zum Erstellen einer PDF-Datei mit Entscheidungsregeln. Ich hoffe es ist hilfreich.
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)
feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)
def output_pdf(clf_, name):
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot_ng as pydot
dot_data = StringIO()
tree.export_graphviz(clf_, out_file=dot_data,
feature_names=feature_names,
class_names=class_name,
filled=True, rounded=True,
special_characters=True,
node_ids=1,)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("%s.pdf"%name)
output_pdf(clf_, name='filename%s'%n)
Ich habe das durchgemacht, aber ich brauchte die Regeln, um in diesem Format geschrieben zu werden
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Also habe ich die Antwort von @paulkernfeld (danke) angepasst, die Sie an Ihre Bedürfnisse anpassen können
def tree_to_code(tree, feature_names, Y):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
pathto=dict()
global k
k = 0
def recurse(node, depth, parent):
global k
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
s= "{} <= {} ".format( name, threshold, node )
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_left[node], depth + 1, node)
s="{} > {}".format( name, threshold)
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_right[node], depth + 1, node)
else:
k=k+1
print(k,')',pathto[parent], tree_.value[node])
recurse(0, 1, 0)
Hier ist eine Möglichkeit, den gesamten Baum mithilfe der SKompiler- Bibliothek in einen einzelnen (nicht unbedingt zu lesbaren) Python-Ausdruck zu übersetzen :
from skompiler import skompile
skompile(dtree.predict).to('python/code')
Dies baut auf der Antwort von @paulkernfeld auf. Wenn Sie einen Datenrahmen X mit Ihren Funktionen und einen Zieldatenrahmen y mit Ihren Antworten haben und eine Vorstellung davon erhalten möchten, welcher y-Wert in welchem Knoten endete (und auch, um ihn entsprechend zu zeichnen), können Sie Folgendes tun:
def tree_to_code(tree, feature_names):
from sklearn.tree import _tree
codelines = []
codelines.append('def get_cat(X_tmp):\n')
codelines.append(' catout = []\n')
codelines.append(' for codelines in range(0,X_tmp.shape[0]):\n')
codelines.append(' Xin = X_tmp.iloc[codelines]\n')
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
#print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
codelines.append( '{}else: # if Xin["{}"] > {}\n'.format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
codelines.append( '{}mycat = {}\n'.format(indent, node))
recurse(0, 1)
codelines.append(' catout.append(mycat)\n')
codelines.append(' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
codelines.append('node_ids = get_cat(X)\n')
return codelines
mycode = tree_to_code(clf,X.columns.values)
# now execute the function and obtain the dataframe with all nodes
exec(''.join(mycode))
node_ids = [int(x[0]) for x in node_ids.values]
node_ids2 = pd.DataFrame(node_ids)
print('make plot')
import matplotlib.cm as cm
colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
#plt.figure(figsize=cm2inch(24, 21))
for i in list(set(node_ids)):
plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))
mytitle = ['y colored by node']
plt.title(mytitle ,fontsize=14)
plt.xlabel('my xlabel')
plt.ylabel(tagname)
plt.xticks(rotation=70)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
plt.tight_layout()
plt.show()
plt.close
nicht die eleganteste Version, aber es macht den Job ...
Ich habe den beliebtesten Code so geändert, dass er in einem Jupyter Notebook Python 3 korrekt eingerückt ist
import numpy as np
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [feature_names[i]
if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature]
print("def tree({}):".format(", ".join(feature_names)))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print("{}if {} <= {}:".format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
print("{}else: # if {} > {}".format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
print("{}return {}".format(indent, np.argmax(tree_.value[node])))
recurse(0, 1)
Hier ist eine Funktion zum Drucken von Regeln eines Scikit-Learn-Entscheidungsbaums unter Python 3 und mit Offsets für bedingte Blöcke, um die Struktur besser lesbar zu machen:
def print_decision_tree(tree, feature_names=None, offset_unit=' '):
'''Plots textual representation of rules of a decision tree
tree: scikit-learn representation of tree
feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
offset_unit: a string of offset of the conditional block'''
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
print(offset+"return " + str(value[node]))
recurse(left, right, threshold, features, 0,0)
Sie können es auch informativer gestalten, indem Sie unterscheiden, zu welcher Klasse es gehört, oder indem Sie sogar den Ausgabewert angeben.
def print_decision_tree(tree, feature_names, offset_unit=' '):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
#print(offset,value[node])
#To remove values from node
temp=str(value[node])
mid=len(temp)//2
tempx=[]
tempy=[]
cnt=0
for i in temp:
if cnt<=mid:
tempx.append(i)
cnt+=1
else:
tempy.append(i)
cnt+=1
val_yes=[]
val_no=[]
res=[]
for j in tempx:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_no.append(j)
for j in tempy:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_yes.append(j)
val_yes = int("".join(map(str, val_yes)))
val_no = int("".join(map(str, val_no)))
if val_yes>val_no:
print(offset,'\033[1m',"YES")
print('\033[0m')
elif val_no>val_yes:
print(offset,'\033[1m',"NO")
print('\033[0m')
else:
print(offset,'\033[1m',"Tie")
print('\033[0m')
recurse(left, right, threshold, features, 0,0)
Hier ist mein Ansatz, um die Entscheidungsregeln in einer Form zu extrahieren, die direkt in SQL verwendet werden kann, damit die Daten nach Knoten gruppiert werden können. (Basierend auf den Ansätzen früherer Poster.)
Das Ergebnis sind nachfolgende CASE
Klauseln, die in eine SQL-Anweisung kopiert werden können, z.
SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN
<conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>
import numpy as np
import pickle
feature_names=.............
features = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""
#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]
def print_decision_tree(tree, feature_names, offset_unit='' ''):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = [''f%d''%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
global Conts
global ContsNode
global Path
global Results
global LeftParents
LeftParents=[]
global RightParents
RightParents=[]
for i in range(len(left)): # This is just to tell you how to create a list.
LeftParents.append(-1)
RightParents.append(-1)
ContsNode.append("")
Path.append("")
for i in range(len(left)): # i is node
if (left[i]==-1 and right[i]==-1):
if LeftParents[i]>=0:
if Path[LeftParents[i]]>" ":
Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]
else:
Path[i]=ContsNode[LeftParents[i]]
if RightParents[i]>=0:
if Path[RightParents[i]]>" ":
Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]
else:
Path[i]=" not " +ContsNode[RightParents[i]]
Results.append(" case when " +Path[i]+" then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")
else:
if LeftParents[i]>=0:
if Path[LeftParents[i]]>" ":
Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]
else:
Path[i]=ContsNode[LeftParents[i]]
if RightParents[i]>=0:
if Path[RightParents[i]]>" ":
Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]
else:
Path[i]=" not "+ContsNode[RightParents[i]]
if (left[i]!=-1):
LeftParents[left[i]]=i
if (right[i]!=-1):
RightParents[right[i]]=i
ContsNode[i]= "( "+ features[i] + " <= " + str(threshold[i]) + " ) "
recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)):
SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)
Jetzt können Sie export_text verwenden.
from sklearn.tree import export_text
r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)
Ein vollständiges Beispiel aus [sklearn] [1]
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
Der Code von Zelazny7 wurde geändert, um SQL aus dem Entscheidungsbaum abzurufen.
# SQL from decision tree
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
le='<='
g ='>'
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
print 'case '
for j,child in enumerate(idx):
clause=' when '
for node in recurse(left, right, child):
if len(str(node))<3:
continue
i=node
if i[1]=='l': sign=le
else: sign=g
clause=clause+i[3]+sign+str(i[2])+' and '
clause=clause[:-4]+' then '+str(j)
print clause
print 'else 99 end as clusters'
Anscheinend hat sich schon vor langer Zeit jemand entschlossen, die folgenden Funktionen zu den Baum-Exportfunktionen des offiziellen Scikits hinzuzufügen (die grundsätzlich nur export_graphviz unterstützen).
def export_dict(tree, feature_names=None, max_depth=None) :
"""Export a decision tree in dict format.
Hier ist sein volles Engagement:
Ich bin mir nicht ganz sicher, was mit diesem Kommentar passiert ist. Sie können aber auch versuchen, diese Funktion zu verwenden.
Ich denke, dies rechtfertigt eine ernsthafte Dokumentationsanfrage an die guten Leute von scikit-learn, um die sklearn.tree.Tree
API, die die zugrunde liegende Baumstruktur darstellt, die DecisionTreeClassifier
als Attribut verfügbar gemacht wird, ordnungsgemäß zu dokumentieren tree_
.
Verwenden Sie einfach die Funktion von sklearn.tree wie folgt
from sklearn.tree import export_graphviz
export_graphviz(tree,
out_file = "tree.dot",
feature_names = tree.columns) //or just ["petal length", "petal width"]
Und dann suchen Sie in Ihrem Projektordner nach der Datei tree.dot , kopieren Sie den gesamten Inhalt und fügen Sie ihn hier ein http://www.webgraphviz.com/ und generieren Sie Ihr Diagramm :)
Vielen Dank für die wunderbare Lösung von @paulkerfeld. Hinzu kommt seine Lösung für all diejenigen , die eine serialisierte Version von Bäumen haben möchten, benutzen Sie einfach tree.threshold
, tree.children_left
, tree.children_right
, tree.feature
und tree.value
. Da die Blätter keine Spalten und daher keine Merkmalsnamen und Kinder haben, ist ihr Platzhalter in tree.feature
und tree.children_***
sind _tree.TREE_UNDEFINED
und _tree.TREE_LEAF
. Jedem Split wird ein eindeutiger Index zugewiesen depth first search
.
Beachten Sie, dass das tree.value
von Form ist[n, 1, 1]
Hier ist eine Funktion, die Python-Code aus einem Entscheidungsbaum generiert, indem sie die Ausgabe von konvertiert export_text
:
import string
from sklearn.tree import export_text
def export_py_code(tree, feature_names, max_depth=100, spacing=4):
if spacing < 2:
raise ValueError('spacing must be > 1')
# Clean up feature names (for correctness)
nums = string.digits
alnums = string.ascii_letters + nums
clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
features = [clean(x) for x in feature_names]
features = ['_'+x if x[0] in nums else x for x in features if x]
if len(set(features)) != len(feature_names):
raise ValueError('invalid feature names')
# First: export tree to text
res = export_text(tree, feature_names=features,
max_depth=max_depth,
decimals=6,
spacing=spacing-1)
# Second: generate Python code from the text
skip, dash = ' '*spacing, '-'*(spacing-1)
code = 'def decision_tree({}):\n'.format(', '.join(features))
for line in repr(tree).split('\n'):
code += skip + "# " + line + '\n'
for line in res.split('\n'):
line = line.rstrip().replace('|',' ')
if '<' in line or '>' in line:
line, val = line.rsplit(maxsplit=1)
line = line.replace(' ' + dash, 'if')
line = '{} {:g}:'.format(line, float(val))
else:
line = line.replace(' {} class:'.format(dash), 'return')
code += skip + line + '\n'
return code
Beispielnutzung:
res = export_py_code(tree, feature_names=names, spacing=4)
print (res)
Beispielausgabe:
def decision_tree(f1, f2, f3):
# DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
# max_features=None, max_leaf_nodes=None,
# min_impurity_decrease=0.0, min_impurity_split=None,
# min_samples_leaf=1, min_samples_split=2,
# min_weight_fraction_leaf=0.0, presort=False,
# random_state=42, splitter='best')
if f1 <= 12.5:
if f2 <= 17.5:
if f1 <= 10.5:
return 2
if f1 > 10.5:
return 3
if f2 > 17.5:
if f2 <= 22.5:
return 1
if f2 > 22.5:
return 1
if f1 > 12.5:
if f1 <= 17.5:
if f3 <= 23.5:
return 2
if f3 > 23.5:
return 3
if f1 > 17.5:
if f1 <= 25:
return 1
if f1 > 25:
return 2
Das obige Beispiel wird mit generiert names = ['f'+str(j+1) for j in range(NUM_FEATURES)]
.
Ein praktisches Feature ist, dass es kleinere Dateigrößen mit reduziertem Abstand erzeugen kann. Einfach einstellen spacing=2
.