我能从决策树中的训练树中提取基本的决策规则(或“决策路径”)作为文本列表吗?

喜欢的东西:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

当前回答

我相信这个答案比这里的其他答案更正确:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

这将打印出一个有效的Python函数。下面是一个树的输出示例,它试图返回它的输入,一个0到10之间的数字。

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

以下是我在其他答案中看到的一些绊脚石:

使用tree_。用阈值== -2来判断节点是否是叶节点不是一个好主意。如果它是一个阈值为-2的真实决策节点呢?相反,你应该看看树。Feature or tree.children_*。 对于tree_中的i,行features = [feature_names[i]。我的sklearn版本崩溃了,因为树。树_。特征为-2(特别是叶节点)。 递归函数中不需要有多个if语句,一个就可以了。

其他回答

下面是一个通过转换export_text的输出从决策树生成Python代码的函数:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

示例用法:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

样例输出:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

上面的示例生成了names = ['f'+str(j+1) for j in range(NUM_FEATURES)]。

一个方便的功能是,它可以生成更小的文件大小与减少间距。只需要设置spacing=2。

我修改了Zelazny7提交的代码来打印一些伪代码:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

如果你在同一个例子中调用get_code(dt, df.columns),你会得到:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

从这个答案中,您可以得到一个可读且高效的表示:https://stackoverflow.com/a/65939892/3746632

输出如下所示。X为一维向量,表示单个实例的特征。

from numba import jit,njit
@njit
def predict(X):
    ret = 0
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    return ret/10

在0.18.0版本中,有一个新的DecisionTreeClassifier方法decision_path。开发人员提供了一个广泛的(文档良好的)演练。

演练中打印树结构的第一部分代码似乎没有问题。但是,我修改了第二节中的代码来检查一个示例。我的更改用# <——表示

在拉取请求#8653和#10951中指出错误后,下面代码中由# <——标记的更改已在演练链接中更新。现在就容易多了。

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

更改sample_id以查看其他示例的决策路径。我没有向开发人员询问这些更改,只是在示例中看起来更直观。

这是您需要的代码

我已经修改了顶部喜欢的代码缩进在一个jupyter笔记本python 3正确

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)