Inspecting trees
Setup
pip install ydf -U
import ydf
import numpy as np
dataset = {
"x1": np.array([0, 0, 0, 1, 1, 1]),
"x2": np.array([1, 1, 0, 0, 1, 1]),
"y": np.array([0, 0, 0, 0, 1, 1]),
}
dataset
{'x1': array([0, 0, 0, 1, 1, 1]), 'x2': array([1, 1, 0, 0, 1, 1]), 'y': array([0, 0, 0, 0, 1, 1])}
Train the model
model = ydf.CartLearner(label="y", min_examples=1, task=ydf.Task.REGRESSION).train(dataset)
model.describe()
Train model on 6 examples
Model trained in 0:00:00.000728
Task : REGRESSION
Label : y
Features (2) : x1 x2
Weights : None
Trained with tuner : No
Model size : 3 kB
Number of records: 6
Number of columns: 3

Number of columns by type:
    NUMERICAL: 3 (100%)

Columns:
NUMERICAL: 3 (100%)
    0: "y" NUMERICAL mean:0.333333 min:0 max:1 sd:0.471405
    1: "x1" NUMERICAL mean:0.5 min:0 max:1 sd:0.5
    2: "x2" NUMERICAL mean:0.666667 min:0 max:1 sd:0.471405

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
The Random Forest does not have out-of-bag evaluation training logs. Train the model with compute_oob_performances=True to compute the training logs. Make sure the training logs have not been removed with pure_serving_model=True.
Variable importances measure the importance of an input feature for a model.
1. "x1" 1.000000 ################
2. "x2" 0.500000

1. "x1" 1.000000

1. "x1" 1.000000
2. "x2" 1.000000

1. "x1" 0.666667
2. "x2" 0.666667
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Tree #0:
    "x1">=0.5 [s:0.111111 n:6 np:3 miss:1] ; pred:0.333333
        ├─(pos)─ "x2">=0.5 [s:0.222222 n:3 np:2 miss:1] ; pred:0.666667
        |        ├─(pos)─ pred:1
        |        └─(neg)─ pred:0
        └─(neg)─ pred:0
Plotting the model
The model's tree is visible in the "structure" tab of model.describe(). You can also print the tree with the print_tree method.
model.print_tree()
'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0
Accessing the tree structure
The get_tree and get_all_trees methods give programmatic access to the structure of the trees.
Note: A CART model contains a single tree, so the tree_idx argument is set to 0. For models with multiple trees, model.num_trees() returns the number of trees.
tree = model.get_tree(tree_idx=0)
tree
Tree(root=NonLeaf(value=RegressionValue(num_examples=6.0, value=0.3333333432674408, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.1111111119389534, attribute=1, threshold=0.5), pos_child=NonLeaf(value=RegressionValue(num_examples=3.0, value=0.6666666865348816, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5), pos_child=Leaf(value=RegressionValue(num_examples=2.0, value=1.0, standard_deviation=0.0)), neg_child=Leaf(value=RegressionValue(num_examples=1.0, value=0.0, standard_deviation=0.0))), neg_child=Leaf(value=RegressionValue(num_examples=3.0, value=0.0, standard_deviation=0.0))))
Do you recognize the structure of the tree printed above? You can access individual parts of the tree. For example, here is the condition on x2:
tree.root.pos_child.condition
NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5)
To display the tree in a more readable form, use the pretty method.
print(tree.pretty(model.data_spec()))
'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0