Bagaimana cara mengekstrak aturan keputusan dari scikit-learn decision-tree?

157

Dapatkah saya mengekstrak aturan keputusan yang mendasari (atau 'jalur keputusan') dari pohon yang terlatih dalam pohon keputusan sebagai daftar tekstual?

Sesuatu seperti:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Terima kasih atas bantuan Anda.

Dror Hilman
sumber
Apakah Anda pernah menemukan jawaban untuk masalah ini? Saya harus mengekspor aturan pohon keputusan dalam format langkah data SAS yang hampir persis seperti yang Anda daftarkan.
Zelazny7
1
Anda dapat menggunakan paket sklearn-porter untuk mengekspor dan mentransformasikan pohon keputusan (juga hutan acak dan pohon yang ditingkatkan) ke C, Java, JavaScript dan lainnya.
Darius
Anda dapat memeriksa tautan ini
yogesh agrawal

Jawaban:

139

Saya percaya bahwa jawaban ini lebih benar daripada jawaban lain di sini:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Ini mencetak fungsi Python yang valid. Berikut ini contoh output untuk pohon yang mencoba mengembalikan inputnya, angka antara 0 dan 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Berikut adalah beberapa batu sandungan yang saya lihat di jawaban lain:

  1. Menggunakan tree_.threshold == -2untuk memutuskan apakah suatu simpul adalah daun bukanlah ide yang baik. Bagaimana jika itu adalah simpul keputusan nyata dengan ambang -2? Sebaliknya, Anda harus melihat tree.featureatau tree.children_*.
  2. Baris features = [feature_names[i] for i in tree_.feature]crash dengan versi sklearn saya, karena beberapa nilai tree.tree_.feature-2 (khusus untuk node daun).
  3. Tidak perlu memiliki beberapa pernyataan if dalam fungsi rekursif, cukup satu saja.
paulkernfeld
sumber
1
Kode ini sangat bagus untuk saya. Namun, saya memiliki 500+ fitur_names sehingga kode output hampir tidak mungkin dipahami manusia. Apakah ada cara untuk membiarkan saya hanya memasukkan nama-nama yang saya ingin tahu tentang ke dalam fungsi?
user3768495
1
Saya setuju dengan komentar sebelumnya. IIUC, print "{}return {}".format(indent, tree_.value[node])harus diubah ke print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))untuk fungsi untuk mengembalikan indeks kelas.
soupault
1
@ Paulkernfeld Ah ya, saya melihat Anda bisa mengulang RandomForestClassifier.estimators_, tapi saya tidak bisa mengetahui cara menggabungkan hasil estimator.
Nathan Lloyd
6
Saya tidak bisa menjalankan ini dalam python 3, bit _tree sepertinya tidak pernah bekerja dan TREE_UNDEFINED tidak didefinisikan. Tautan ini membantu saya. Meskipun kode yang diekspor tidak dapat dijalankan secara langsung dalam python, kode ini mirip c dan cukup mudah untuk diterjemahkan ke bahasa lain: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah
1
@ Yosia, tambahkan () ke pernyataan cetak untuk membuatnya bekerja di python3. eg print "bla"=>print("bla")
Nir
48

Saya membuat fungsi saya sendiri untuk mengekstrak aturan dari pohon keputusan yang dibuat oleh sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Fungsi ini dimulai dengan node (diidentifikasi oleh -1 pada array anak) dan kemudian secara rekursif menemukan orang tua. Saya menyebutnya 'garis silsilah' simpul. Sepanjang jalan, saya ambil nilai-nilai yang perlu saya buat jika / kemudian / logika SAS:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Kumpulan tupel di bawah ini berisi semua yang saya butuhkan untuk membuat SAS if / then / else statement. Saya tidak suka menggunakan doblok di SAS yang mengapa saya membuat logika yang menggambarkan seluruh jalur node. Integer tunggal setelah tupel adalah ID dari simpul terminal di jalur. Semua tupel sebelumnya bergabung untuk membuat simpul itu.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Output GraphViz dari pohon contoh

Zelazny7
sumber
apakah jenis pohon ini benar karena col1 akan kembali lagi satu adalah col1 <= 0,50000 dan satu col1 <= 2,5000 jika ya, apakah ini jenis rekursi yang digunakan di perpustakaan
jayant singh
cabang kanan akan memiliki catatan di antaranya (0.5, 2.5]. Pohon-pohon dibuat dengan partisi rekursif. Tidak ada yang mencegah suatu variabel dipilih berulang kali.
Zelazny7
oke bisa Anda jelaskan bagian rekursi apa yang terjadi karena saya telah menggunakannya dalam kode saya dan hasil yang serupa terlihat
jayant singh
38

Saya memodifikasi kode yang dikirimkan oleh Zelazny7 untuk mencetak beberapa pseudocode:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

jika Anda memanggil get_code(dt, df.columns)contoh yang sama Anda akan mendapatkan:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
Daniele
sumber
1
Dapatkah Anda memberi tahu, apa tepatnya [[1. 0.]] dalam pernyataan pengembalian artinya dalam output di atas. Saya bukan orang Python, tetapi mengerjakan hal yang sama. Jadi itu akan baik bagi saya jika Anda membuktikan beberapa detail sehingga akan lebih mudah bagi saya.
Subhradip Bose
1
@ user3156186 Berarti ada satu objek di kelas '0' dan nol objek di kelas '1'
Daniele
1
@Aniele, apakah Anda tahu bagaimana kelas-kelas itu dipesan? Saya kira alfanumerik, tetapi saya belum menemukan konfirmasi di mana pun.
IanS
Terima kasih! Untuk skenario tepi kasus di mana nilai ambang sebenarnya -2, kita mungkin perlu mengubah (threshold[node] != -2)ke ( left[node] != -1)(mirip dengan metode di bawah ini untuk mendapatkan id dari simpul anak)
tlingf
@Daniele, ada ide bagaimana membuat fungsi Anda "get_code" "mengembalikan" nilai dan tidak "mencetak" itu, karena saya perlu mengirimnya ke fungsi lain?
RoyaumeIX
17

Scikit belajar memperkenalkan metode baru yang lezat yang disebut export_textdalam versi 0.21 (Mei 2019) untuk mengekstrak aturan dari pohon. Dokumentasi di sini . Tidak perlu lagi membuat fungsi khusus.

Setelah sesuai dengan model Anda, Anda hanya perlu dua baris kode. Pertama, impor export_text:

from sklearn.tree.export import export_text

Kedua, buat objek yang akan berisi aturan Anda. Untuk membuat aturan terlihat lebih mudah dibaca, gunakan feature_namesargumen dan berikan daftar nama fitur Anda. Misalnya, jika model Anda dipanggil modeldan fitur Anda dinamai dalam kerangka data yang disebut X_train, Anda bisa membuat objek yang disebut tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Kemudian cukup cetak atau simpan tree_rules. Output Anda akan terlihat seperti ini:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
yzerman
sumber
14

Ada DecisionTreeClassifiermetode baru decision_path,, dalam rilis 0.18.0 . Para pengembang menyediakan luas (terdokumentasi) walkthrough .

Bagian pertama dari kode dalam panduan yang mencetak struktur pohon tampaknya OK. Namun, saya memodifikasi kode di bagian kedua untuk menginterogasi satu sampel. Perubahan saya dilambangkan dengan# <--

Sunting Perubahan yang ditandai oleh # <--dalam kode di bawah ini telah diperbarui di tautan langkah-langkah setelah kesalahan ditunjukkan dalam permintaan tarik # 8653 dan # 10951 . Jauh lebih mudah untuk diikuti sekarang.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Ubah sample_iduntuk melihat jalur keputusan untuk sampel lain. Saya belum bertanya kepada pengembang tentang perubahan ini, hanya tampak lebih intuitif ketika mengerjakan contoh.

Kevin
sumber
Anda teman saya adalah legenda! ada ide bagaimana merencanakan pohon keputusan untuk sampel tertentu? banyak bantuan dihargai
1
Terima kasih Victor, mungkin yang terbaik untuk mengajukan ini sebagai pertanyaan terpisah karena memplot persyaratan dapat spesifik untuk kebutuhan pengguna. Anda mungkin akan mendapatkan respons yang baik jika Anda ingin tahu seperti apa bentuk output yang Anda inginkan.
Kevin
hei kevin, saya membuat pertanyaan stackoverflow.com/questions/48888893/...
apakah Anda akan berbaik hati melihat: stackoverflow.com/questions/52654280/…
Alexander Chervov
Bisakah Anda jelaskan bagian yang disebut node_index, tidak mendapatkan bagian itu. apa fungsinya?
Anindya Sankar Dey
12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Anda dapat melihat Pohon digraf. Kemudian, clf.tree_.featuredan clf.tree_.valueadalah array fitur pembelahan node dan array nilai node masing-masing. Anda dapat merujuk lebih detail dari sumber github ini .

lennon310
sumber
1
Ya, saya tahu cara menggambar pohon - tetapi saya membutuhkan versi yang lebih tekstual - aturan. sesuatu seperti: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman
4

Hanya karena semua orang sangat membantu saya hanya akan menambahkan modifikasi untuk solusi Zelazny7 dan Daniele yang indah. Yang ini untuk python 2.7, dengan tab agar lebih mudah dibaca:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
Ruslan
sumber
3

Kode di bawah ini adalah pendekatan saya di bawah anaconda python 2.7 plus nama paket "pydot-ng" untuk membuat file PDF dengan aturan keputusan. Saya harap ini membantu.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

pertunjukan grafik pohon di sini

TED Zhao
sumber
3

Saya telah melalui ini, tetapi saya membutuhkan aturan untuk ditulis dalam format ini

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Jadi saya mengadaptasi jawaban @paulkernfeld (terima kasih) yang dapat Anda sesuaikan dengan kebutuhan Anda

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)
Ala Ham
sumber
3

Berikut adalah cara untuk menerjemahkan seluruh pohon menjadi ekspresi python tunggal (tidak harus terlalu bisa dibaca manusia) menggunakan perpustakaan SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')
KT.
sumber
3

Ini didasarkan pada jawaban @ paulkernfeld. Jika Anda memiliki kerangka data X dengan fitur-fitur Anda dan kerangka data target y dengan resonses Anda dan Anda ingin mendapatkan ide mana nilai y berakhir di simpul mana (dan juga semut untuk memplotnya), Anda dapat melakukan hal berikut:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

bukan versi yang paling elegan tetapi berfungsi ...

sepatu kuda
sumber
1
Ini adalah pendekatan yang baik ketika Anda ingin mengembalikan baris kode alih-alih hanya mencetaknya.
Hajar Homayouni
3

Ini adalah kode yang Anda butuhkan

Saya telah memodifikasi kode paling disukai untuk indentasi di jupyter notebook python 3 dengan benar

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)
Cameron Sorensen
sumber
2

Berikut adalah fungsi, aturan pencetakan pohon keputusan scikit-learn di bawah python 3 dan dengan offset untuk blok bersyarat untuk membuat struktur lebih mudah dibaca:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)
Apogentus
sumber
2

Anda juga dapat membuatnya lebih informatif dengan membedakan kelas mana yang dimiliki atau bahkan dengan menyebutkan nilai outputnya.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

masukkan deskripsi gambar di sini

Amit Rautray
sumber
2

Berikut ini adalah pendekatan saya untuk mengekstrak aturan keputusan dalam bentuk yang dapat digunakan secara langsung dalam sql, sehingga data dapat dikelompokkan berdasarkan node. (Berdasarkan pendekatan poster sebelumnya.)

Hasilnya akan menjadi CASEklausa berikutnya yang dapat disalin ke pernyataan sql, mis.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)
Gat
sumber
1

Sekarang Anda dapat menggunakan export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Contoh lengkap dari [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
kevin
sumber
0

Memodifikasi kode Zelazny7 untuk mengambil SQL dari pohon keputusan.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'
Arslan
sumber
0

Rupanya sudah lama ada seseorang yang memutuskan untuk mencoba menambahkan fungsi berikut ke fungsi ekspor pohon scikit resmi (yang pada dasarnya hanya mendukung export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Inilah komitmen penuhnya:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Tidak yakin apa yang terjadi pada komentar ini. Tetapi Anda juga bisa mencoba menggunakan fungsi itu.

Saya pikir ini menjamin permintaan dokumentasi serius kepada orang-orang baik scikit-belajar untuk mendokumentasikan sklearn.tree.TreeAPI yang merupakan struktur pohon yang mendasari yang DecisionTreeClassifiermengekspos sebagai atributnya tree_.

Aris Koning
sumber
0

Cukup gunakan fungsi dari sklearn.tree seperti ini

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Dan kemudian lihat di folder proyek Anda untuk file tree.dot , salin SEMUA konten dan tempel di sini http://www.webgraphviz.com/ dan hasilkan grafik Anda :)

chainstair
sumber
0

Terima kasih atas solusi luar biasa dari @paulkerfeld. Di atas solusi nya, untuk semua orang yang ingin memiliki versi serial pohon, hanya menggunakan tree.threshold, tree.children_left, tree.children_right, tree.featuredan tree.value. Karena daun tidak memiliki perpecahan dan karenanya tidak ada fitur nama dan anak-anak, placeholder mereka tree.featuredan tree.children_***yang _tree.TREE_UNDEFINEDdan _tree.TREE_LEAF. Setiap pemisahan diberi indeks unik oleh depth first search.
Perhatikan bahwa tree.valuebentuknya[n, 1, 1]

Yanqi Huang
sumber
0

Berikut adalah fungsi yang menghasilkan kode Python dari pohon keputusan dengan mengubah output dari export_text :

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Penggunaan sampel:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Output sampel:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

Contoh di atas dihasilkan dengan names = ['f'+str(j+1) for j in range(NUM_FEATURES)].

Salah satu fitur praktis adalah dapat menghasilkan ukuran file yang lebih kecil dengan pengurangan jarak. Baru diatur spacing=2.

Andriy Makukha
sumber