Tree

Tree dataclass

Tree(root: AbstractNode)

plot

plot(dataspec: DataSpecification, max_depth: Optional[int], label_classes: Optional[Sequence[str]], options: Optional[PlotOptions] = None, d3js_url: str = 'https://d3js.org/d3.v6.min.js') -> TreePlot

Plots a decision tree.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataspec | DataSpecification | Dataspec of the tree. | required |
| max_depth | Optional[int] | Maximum tree depth of the plot. Set to None for full depth. | required |
| label_classes | Optional[Sequence[str]] | For classification, the label classes of the dataset. | required |
| options | Optional[PlotOptions] | Advanced options for plotting. Set to None for default style. | None |
| d3js_url | str | URL to load the d3.js library from. | 'https://d3js.org/d3.v6.min.js' |

Returns:

| Type | Description |
|---|---|
| TreePlot | The HTML content displaying the tree. |

pretty

pretty(dataspec: DataSpecification, max_depth: Optional[int] = 6) -> str

Returns a printable representation of the decision tree.

Usage example:

```python
import ydf

model = ydf.load_model("my_model")
tree = model.get_tree(0)
print(tree.pretty(model.data_spec()))
```

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataspec | DataSpecification | Dataspec of the tree. | required |
| max_depth | Optional[int] | Maximum printed depth. | 6 |
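The effect of a depth limit on a printed tree can be pictured with the minimal, self-contained sketch below. It does not use ydf: `Leaf`, `NonLeaf` and `pretty` here are hypothetical, simplified stand-ins (a leaf holds a bare number, a split holds a feature name and a threshold), and the output format is invented for illustration, not the format ydf prints.

```python
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical, simplified stand-ins for the documented classes (not ydf types).
@dataclass
class Leaf:
    value: float  # A bare number instead of an AbstractValue.


@dataclass
class NonLeaf:
    feature: str  # A feature name instead of an AbstractCondition.
    threshold: float
    pos_child: "Node"
    neg_child: "Node"


Node = Union[Leaf, NonLeaf]


def pretty(node: Node, max_depth: Optional[int] = 6, depth: int = 0) -> str:
    """Renders the tree; splits at depth max_depth or deeper are shown as '...'."""
    indent = "    " * depth
    if isinstance(node, Leaf):
        return f"{indent}value: {node.value}\n"
    if max_depth is not None and depth >= max_depth:
        # Depth limit reached: hide this split and everything below it.
        return f"{indent}...\n"
    return (
        f"{indent}{node.feature} >= {node.threshold}\n"
        + pretty(node.pos_child, max_depth, depth + 1)
        + pretty(node.neg_child, max_depth, depth + 1)
    )


tree = NonLeaf("age", 30.0, Leaf(1.0), NonLeaf("income", 5.0, Leaf(2.0), Leaf(3.0)))
print(pretty(tree, max_depth=1))     # The "income" split is replaced by "..."
print(pretty(tree, max_depth=None))  # Full depth
```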

Conditions

AbstractCondition dataclass

AbstractCondition(missing: bool, score: float)

Generic condition.

Attrs

- missing: Result of the evaluation of the condition if the input feature is missing.
- score: Score of the condition. Its semantics depend on the learning algorithm.

NumericalHigherThanCondition dataclass

NumericalHigherThanCondition(missing: bool, score: float, attribute: int, threshold: float)

Bases: AbstractCondition

Condition of the form "attribute >= threshold".

Attrs

- attribute: Attribute tested by the condition.
- threshold: Threshold value of the condition.
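A minimal sketch of how such a condition could be evaluated, including the `missing` fallback from `AbstractCondition`. It assumes a hypothetical row layout (a list indexed by attribute index, with `None` for a missing value) and a flattened mirror of the dataclass; it is not the ydf class or ydf's inference code.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class NumericalHigherThanCondition:
    missing: bool
    score: float
    attribute: int  # Index of the tested column in the (hypothetical) row.
    threshold: float


def evaluate(cond: NumericalHigherThanCondition, row: Sequence[Optional[float]]) -> bool:
    value = row[cond.attribute]
    if value is None:
        # Missing input: use the precomputed outcome stored in the condition.
        return cond.missing
    # The comparison is inclusive: a value equal to the threshold passes.
    return value >= cond.threshold


cond = NumericalHigherThanCondition(missing=False, score=0.5, attribute=1, threshold=30.0)
print(evaluate(cond, [None, 30.0]))  # True
print(evaluate(cond, [1.0, None]))   # False (falls back to `missing`)
```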

CategoricalIsInCondition dataclass

CategoricalIsInCondition(missing: bool, score: float, attribute: int, mask: Sequence[int])

Bases: AbstractCondition

Condition of the form "attribute in mask".

Attrs

- attribute: Attribute tested by the condition.
- mask: Sorted mask values.
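Because `mask` is sorted, membership can be tested with a binary search. The sketch below assumes categorical values are encoded as integer indices (as the `Sequence[int]` type of `mask` suggests) and uses a hypothetical row layout with `None` for missing values; the class is a mirror for illustration, not the ydf type.

```python
import bisect
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class CategoricalIsInCondition:
    missing: bool
    score: float
    attribute: int
    mask: Sequence[int]  # Sorted category indices.


def evaluate(cond: CategoricalIsInCondition, row: Sequence[Optional[int]]) -> bool:
    value = row[cond.attribute]
    if value is None:
        return cond.missing
    # Since `mask` is sorted, membership is a binary search: O(log n).
    i = bisect.bisect_left(cond.mask, value)
    return i < len(cond.mask) and cond.mask[i] == value


cond = CategoricalIsInCondition(missing=True, score=0.1, attribute=0, mask=[2, 5, 7])
print(evaluate(cond, [5]), evaluate(cond, [6]), evaluate(cond, [None]))  # True False True
```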

CategoricalSetContainsCondition dataclass

CategoricalSetContainsCondition(missing: bool, score: float, attribute: int, mask: Sequence[int])

Bases: AbstractCondition

Condition of the form "attribute intersect mask != empty", i.e., the categorical-set attribute contains at least one of the mask values.

Attrs

- attribute: Attribute tested by the condition.
- mask: Sorted mask values.
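The non-empty-intersection test maps directly onto Python's `set.isdisjoint`. This is a sketch under assumptions: the class is a hypothetical mirror, a categorical-set value is an iterable of integer indices, and a missing value is `None` (distinct from an empty set).

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class CategoricalSetContainsCondition:
    missing: bool
    score: float
    attribute: int
    mask: Sequence[int]


def evaluate(cond: CategoricalSetContainsCondition, row: Sequence[Optional[Iterable[int]]]) -> bool:
    values = row[cond.attribute]
    if values is None:
        return cond.missing
    # True when the input set and the mask share at least one item.
    return not set(values).isdisjoint(cond.mask)


cond = CategoricalSetContainsCondition(missing=False, score=0.2, attribute=0, mask=[1, 4])
print(evaluate(cond, [{0, 4}]))  # True: 4 is in the mask
print(evaluate(cond, [{2, 3}]))  # False: no common item
```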

DiscretizedNumericalHigherThanCondition dataclass

DiscretizedNumericalHigherThanCondition(missing: bool, score: float, attribute: int, threshold_idx: int)

Bases: AbstractCondition

Condition of the form "attribute >= bounds[threshold_idx]".

Attrs

- attribute: Attribute tested by the condition.
- threshold_idx: Index of the threshold in the attribute's discretization bounds, as stored in the dataspec.
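Here the condition stores only an index, so the actual threshold must be looked up. In the sketch below, `bounds` is a hypothetical list standing in for the discretization bounds that would come from the dataspec, and the class and row layout are illustrative assumptions rather than ydf's API.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class DiscretizedNumericalHigherThanCondition:
    missing: bool
    score: float
    attribute: int
    threshold_idx: int  # Index into the attribute's discretization bounds.


def evaluate(
    cond: DiscretizedNumericalHigherThanCondition,
    row: Sequence[Optional[float]],
    bounds: Sequence[float],
) -> bool:
    value = row[cond.attribute]
    if value is None:
        return cond.missing
    # The threshold is not stored in the condition; it is looked up by index.
    return value >= bounds[cond.threshold_idx]


bounds = [0.5, 1.5, 3.0]  # Hypothetical bounds, normally read from the dataspec.
cond = DiscretizedNumericalHigherThanCondition(missing=True, score=0.0, attribute=0, threshold_idx=1)
print(evaluate(cond, [2.0], bounds), evaluate(cond, [1.0], bounds))  # True False
```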

IsMissingInCondition dataclass

IsMissingInCondition(missing: bool, score: float, attribute: int)

Bases: AbstractCondition

Condition of the form "attribute is missing".

Attrs

- attribute: Attribute (or one of the attributes) tested by the condition.
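A sketch of this condition, assuming a hypothetical convention in which `None`, or NaN for a numerical value, denotes a missing input. Since a missing value makes this condition true by definition, the sketch does not read the `missing` field. The class is an illustrative mirror, not the ydf type.

```python
import math
from dataclasses import dataclass
from typing import Any, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class IsMissingInCondition:
    missing: bool
    score: float
    attribute: int


def is_missing(value: Any) -> bool:
    # Hypothetical convention: None, or NaN for numerical values, means missing.
    return value is None or (isinstance(value, float) and math.isnan(value))


def evaluate(cond: IsMissingInCondition, row: Sequence[Any]) -> bool:
    return is_missing(row[cond.attribute])


cond = IsMissingInCondition(missing=True, score=0.0, attribute=1)
print(evaluate(cond, [0.0, None]), evaluate(cond, [0.0, float("nan")]), evaluate(cond, [None, 2.0]))
# True True False
```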

IsTrueCondition dataclass

IsTrueCondition(missing: bool, score: float, attribute: int)

Bases: AbstractCondition

Condition of the form "attribute is true".

Attrs

- attribute: Attribute tested by the condition.
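For a boolean attribute, evaluation reduces to reading the value, with the usual `missing` fallback. As in the other sketches, the class is a hypothetical mirror and the row layout (a list indexed by attribute, `None` for missing) is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class IsTrueCondition:
    missing: bool
    score: float
    attribute: int


def evaluate(cond: IsTrueCondition, row: Sequence[Optional[bool]]) -> bool:
    value = row[cond.attribute]
    if value is None:
        # Missing input: use the precomputed outcome stored in the condition.
        return cond.missing
    return bool(value)


cond = IsTrueCondition(missing=False, score=0.3, attribute=0)
print(evaluate(cond, [True]), evaluate(cond, [False]), evaluate(cond, [None]))  # True False False
```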

NumericalSparseObliqueCondition dataclass

NumericalSparseObliqueCondition(missing: bool, score: float, attributes: Sequence[int], weights: Sequence[float], threshold: float)

Bases: AbstractCondition

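Condition of the form "attributes * weights >= threshold", i.e., the weighted sum of the tested attribute values, sum_i weights[i] * attributes[i], is compared to the threshold.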

Attrs

- attributes: Attributes tested by the condition.
- weights: Weights for each of the attributes.
- threshold: Threshold value of the condition.
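The projection is a dot product restricted to the few attributes listed in the condition, which is what makes it "sparse". The sketch below is a hypothetical mirror; in particular, falling back to `missing` whenever any tested attribute is missing is a simplifying assumption for illustration, not a documented ydf behavior.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical mirror of the documented dataclass (not imported from ydf).
@dataclass
class NumericalSparseObliqueCondition:
    missing: bool
    score: float
    attributes: Sequence[int]
    weights: Sequence[float]
    threshold: float


def evaluate(cond: NumericalSparseObliqueCondition, row: Sequence[Optional[float]]) -> bool:
    values = [row[a] for a in cond.attributes]
    if any(v is None for v in values):
        # Simplifying assumption: any missing input falls back to `missing`.
        return cond.missing
    # Weighted sum over the selected attributes only; other columns are ignored.
    projection = sum(w * v for w, v in zip(cond.weights, values))
    return projection >= cond.threshold


cond = NumericalSparseObliqueCondition(
    missing=False, score=0.0, attributes=[0, 2], weights=[0.5, -1.0], threshold=0.0
)
print(evaluate(cond, [4.0, 99.0, 1.0]))  # 0.5*4.0 - 1.0*1.0 = 1.0 >= 0.0 -> True
print(evaluate(cond, [1.0, 0.0, 2.0]))   # 0.5*1.0 - 1.0*2.0 = -1.5 -> False
```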

Nodes

AbstractNode

is_leaf abstractmethod property

is_leaf: bool

Whether the node is a leaf.

Leaf dataclass

Leaf(value: AbstractValue)

Bases: AbstractNode

NonLeaf dataclass

NonLeaf(value: Optional[AbstractValue] = None, condition: Optional[AbstractCondition] = None, pos_child: Optional[AbstractNode] = None, neg_child: Optional[AbstractNode] = None)

Bases: AbstractNode
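The node classes combine into a simple prediction loop: starting at the root, follow `pos_child` when the condition holds and `neg_child` otherwise, until `is_leaf` is true. The sketch below is a hypothetical, simplified mirror (a leaf holds a bare number, a condition is a plain predicate, and the optional `value` of `NonLeaf` is omitted); the routing rule is an assumption based on the child names, not ydf's inference code.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Sequence


class AbstractNode(ABC):
    @property
    @abstractmethod
    def is_leaf(self) -> bool:
        """Whether the node is a leaf."""


@dataclass
class Leaf(AbstractNode):
    value: float  # Simplified: a bare number instead of an AbstractValue.

    @property
    def is_leaf(self) -> bool:
        return True


@dataclass
class NonLeaf(AbstractNode):
    # Simplified: a predicate over the row instead of an AbstractCondition.
    condition: Callable[[Sequence[float]], bool]
    pos_child: AbstractNode
    neg_child: AbstractNode

    @property
    def is_leaf(self) -> bool:
        return False


def predict(node: AbstractNode, row: Sequence[float]) -> float:
    # Descend until a leaf, taking pos_child when the condition holds.
    while not node.is_leaf:
        node = node.pos_child if node.condition(row) else node.neg_child
    return node.value


tree = NonLeaf(
    condition=lambda r: r[0] >= 30.0,
    pos_child=Leaf(1.0),
    neg_child=NonLeaf(lambda r: r[1] >= 5.0, Leaf(2.0), Leaf(3.0)),
)
print(predict(tree, [40.0, 0.0]), predict(tree, [10.0, 6.0]), predict(tree, [10.0, 1.0]))
# 1.0 2.0 3.0
```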

Values

AbstractValue dataclass

AbstractValue(num_examples: float)

A generic value/prediction/output.

Attrs

- num_examples: Weighted number of examples in the node.

RegressionValue dataclass

RegressionValue(num_examples: float, value: float, standard_deviation: Optional[float] = None)

Bases: AbstractValue

The regression value of a regression tree.

Also used in gradient boosted trees for classification and ranking.

Attrs

- value: Value of the node. Its semantics depend on the model: for regression Random Forests and regression GBDTs, it is a regression value in the same unit as the label; for classification and ranking GBDTs, it is a logit.
- standard_deviation: Optional standard deviation attached to the value.
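Since classification GBDT leaves hold logits, the leaf values reached in each tree are added up and then mapped to a probability. The sketch below assumes binary classification with one tree per iteration and a sigmoid link; `initial_logit` is a hypothetical parameter standing in for the model's initial prediction, and multi-class models are not covered.

```python
import math
from typing import Sequence


def binary_gbt_probability(leaf_values: Sequence[float], initial_logit: float = 0.0) -> float:
    # Each tree contributes the `value` of the leaf the example reaches; for a
    # binary classification GBT these values are added in logit space.
    logit = initial_logit + sum(leaf_values)
    # The sigmoid maps the summed logit to a probability of the positive class.
    return 1.0 / (1.0 + math.exp(-logit))


print(binary_gbt_probability([0.3, -0.3]))   # 0.5
print(binary_gbt_probability([math.log(3)]))  # ~0.75
```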

ProbabilityValue dataclass

ProbabilityValue(num_examples: float, probability: Sequence[float])

Bases: AbstractValue

A probability distribution value.

Used for Random Forest and CART classification trees.

Attrs

- probability: An array of probabilities of the label classes, i.e., the i-th value is the probability of the "label_value_idx_to_value(..., i)" class. Note that the first value is reserved for the out-of-vocabulary (OOV) value.
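In a Random Forest, the per-tree `probability` vectors reached by an example have to be aggregated. The sketch below shows two common aggregation rules, averaging the distributions and letting each tree vote for its most probable class; which rule a given ydf model uses is not stated here, and both function names are hypothetical. The vectors are treated as-is, without special handling of the reserved OOV slot.

```python
from typing import List, Sequence


def average_distribution(per_tree: Sequence[Sequence[float]]) -> List[float]:
    # Element-wise mean of the per-tree probability vectors.
    n = len(per_tree)
    return [sum(column) / n for column in zip(*per_tree)]


def vote_distribution(per_tree: Sequence[Sequence[float]]) -> List[float]:
    # Each tree votes for its most probable class; votes are then normalized.
    n = len(per_tree)
    votes = [0] * len(per_tree[0])
    for probability in per_tree:
        votes[max(range(len(probability)), key=probability.__getitem__)] += 1
    return [count / n for count in votes]


per_tree = [[0.25, 0.75], [0.75, 0.25], [0.4, 0.6]]
print(vote_distribution(per_tree))  # [0.3333333333333333, 0.6666666666666666]
```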

UpliftValue dataclass

UpliftValue(num_examples: float, treatment_effect: Sequence[float])

Bases: AbstractValue

The uplift value of a classification or regression uplift tree.

Attrs

- treatment_effect: An array of the effects on the treatment groups. The i-th element of this array is the effect of the (i+1)-th treatment compared to the control group.
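Because the i-th entry refers to treatment i + 1, code reading `treatment_effect` has to shift indices by one, with 0 left for the control group. The hypothetical helper below applies an illustrative policy (pick the treatment with the largest positive effect, otherwise keep the control group); it is not part of ydf.

```python
from typing import Sequence


def best_treatment(treatment_effect: Sequence[float]) -> int:
    """Returns 0 for control, or the 1-based treatment with the largest positive effect."""
    best_idx, best_effect = 0, 0.0
    for i, effect in enumerate(treatment_effect):
        # treatment_effect[i] is the effect of treatment i + 1 versus control.
        if effect > best_effect:
            best_idx, best_effect = i + 1, effect
    return best_idx


print(best_treatment([0.1, 0.3, -0.2]))  # 2
print(best_treatment([-0.1]))            # 0 (no treatment beats control)
```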