What is a Decision Tree?

A decision tree is a machine learning algorithm used for both regression and classification tasks. The name “decision tree” comes from the fact that the algorithm keeps dividing the dataset into smaller and smaller portions until each portion can be given a single label, which in the extreme case means dividing the data down to single instances. If you were to visualize the results of the algorithm, the way the categories are divided would resemble a tree with many leaves.

That’s a quick definition of a decision tree, but let’s take a deep dive into how decision trees work. Having a better understanding of how decision trees operate, as well as their use cases, will assist you in knowing when to utilize them during your machine learning projects.

Format of a Decision Tree

A decision tree is a lot like a flowchart. To utilize a flowchart you start at the starting point, or root, of the chart and then based on how you answer the filtering criteria of that starting node you move to one of the next possible nodes. This process is repeated until an ending is reached.

Decision trees operate in essentially the same manner, with every internal node in the tree being a test or filtering criterion on some feature. The nodes on the outside, the endpoints of the tree, are the labels for the datapoint in question and they are dubbed “leaves”. Each branch that leads from an internal node to the next node corresponds to one outcome of that node's test, so a chain of branches represents a conjunction of feature conditions. The rules used to classify the datapoints are the paths that run from the root to the leaves.
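To make the flowchart analogy concrete, here is a minimal sketch of a hand-built tree and the walk from root to leaf. The class names `Node` and `Leaf`, the function `classify`, and the toy features are all assumptions made up for illustration; they do not come from any particular library.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    label: str  # class label given to every datapoint that ends up here


@dataclass
class Node:
    feature: str                    # feature tested at this internal node
    if_true: Union["Node", Leaf]    # branch followed when the feature is True
    if_false: Union["Node", Leaf]   # branch followed when the feature is False


def classify(tree, datapoint):
    """Walk from the root to a leaf, following one branch per test."""
    while isinstance(tree, Node):
        tree = tree.if_true if datapoint[tree.feature] else tree.if_false
    return tree.label


# Hypothetical toy tree: features and labels are invented for this example.
tree = Node(
    "barks",
    Leaf("dog"),
    Node("retractable_claws", Leaf("cat"), Leaf("dog")),
)

print(classify(tree, {"barks": False, "retractable_claws": True}))  # cat
```

Each root-to-leaf path in this sketch reads as one rule, for example "does not bark and has retractable claws, therefore cat".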

Algorithms for Decision Trees

Decision trees operate on an algorithmic approach which splits the dataset up into individual data points based on different criteria. These splits are done with different variables, or the different features of the dataset. For example, if the goal is to determine whether a dog or a cat is being described by the input features, the variables the data is split on might be things like “claws” and “barks”.

So what algorithms are used to actually split the data into branches and leaves? There are various methods that can be used to split a tree up, but the most common method of splitting is probably a technique referred to as “recursive binary split”. When carrying out this method of splitting, the process starts at the root, and the number of features in the dataset determines the number of candidate splits. A function is used to determine how much accuracy every candidate split will cost, and the split is made using the criterion that sacrifices the least accuracy. This process is carried out recursively, and sub-groups are formed using the same general strategy.
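A single step of this greedy search might look like the sketch below. It is only an illustration under stated assumptions: the helper names `misclassified` and `best_split` are hypothetical, features are assumed to be numeric, and the cost used here is a simple count of points a majority vote would get wrong, standing in for whatever cost function is actually chosen.

```python
from collections import Counter


def misclassified(labels):
    """Number of points a majority vote in this group would get wrong."""
    return len(labels) - max(Counter(labels).values())


def best_split(rows, labels, cost=misclassified):
    """Try every (feature, threshold) pair and return the cheapest one.

    Returns (total_cost, feature_index, threshold), or None if no split
    separates the data into two non-empty groups.
    """
    best = None
    for feature in range(len(rows[0])):
        for threshold in sorted({row[feature] for row in rows}):
            left = [y for row, y in zip(rows, labels) if row[feature] < threshold]
            right = [y for row, y in zip(rows, labels) if row[feature] >= threshold]
            if not left or not right:
                continue  # everything landed on one side: not a real split
            total = cost(left) + cost(right)
            if best is None or total < best[0]:
                best = (total, feature, threshold)
    return best


# Made-up data: feature 0 separates the classes cleanly, feature 1 does not.
rows = [[1.0, 7.0], [2.0, 3.0], [3.0, 8.0], [4.0, 2.0]]
labels = ["a", "a", "b", "b"]

print(best_split(rows, labels))  # (0, 0, 3.0)
```

The recursive part of the method comes from calling the same search again on the left group and on the right group, and so on down the tree.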

In order to determine the cost of the split, a cost function is used. A different cost function is used for regression tasks and classification tasks. The goal of both cost functions is to determine which branches have the most similar response values, or the most homogeneous branches. This makes intuitive sense when you consider that you want test data of a given class to follow the same paths through the tree.

In terms of the regression cost function for recursive binary split, the formula used to calculate the cost is as follows:

sum((y – prediction)^2)

The prediction for a particular group of data points is the mean of the responses of the training data for that group. All the data points are run through the cost function to determine the cost for all the possible splits and the split with the lowest cost is selected.
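As a small worked sketch of the regression cost, the code below scores two candidate splits of the same six responses. The function names `sse` and `split_cost` and the numbers are assumptions chosen for illustration.

```python
def sse(values):
    """Sum of squared errors around the group mean.

    The group mean plays the role of the prediction for every point
    in the group.
    """
    prediction = sum(values) / len(values)
    return sum((y - prediction) ** 2 for y in values)


def split_cost(left, right):
    """Cost of a split: the squared error summed over both groups."""
    return sse(left) + sse(right)


# Made-up responses. The first split keeps similar values together,
# the second mixes small and large values in each group.
good = split_cost([1.0, 2.0, 3.0], [10.0, 11.0, 12.0])
bad = split_cost([1.0, 10.0, 3.0], [2.0, 11.0, 12.0])

print(good)        # 4.0
print(good < bad)  # True
```

The split with the lower cost is the one whose groups have the most similar response values, which is the split the algorithm would select.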

Regarding the cost function for classification, the function is as follows:

G = sum(pk * (1 – pk))

This is the Gini score, where “pk” is the proportion of instances in a group that belong to class k. It is a measurement of the effectiveness of a split, based on how many instances of different classes are in the groups resulting from the split. In other words, it quantifies how mixed the groups are after the split. An optimal split is when all the groups resulting from the split consist only of inputs from one class. If an optimal split has been created the “pk” value will be either 0 or 1 and G will be equal to zero. You might be able to guess that the worst-case split is one where there is a 50-50 representation of the classes in the split, in the case of binary classification. In this case, the “pk” value would be 0.5 and G would also be 0.5.
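The two boundary cases described above can be checked with a few lines of code. This is a minimal sketch; the function name `gini` and the example labels are assumptions for illustration.

```python
from collections import Counter


def gini(labels):
    """Gini score of one group: the sum over classes of pk * (1 - pk)."""
    n = len(labels)
    return sum((count / n) * (1 - count / n) for count in Counter(labels).values())


print(gini(["cat", "cat", "cat", "cat"]))  # 0.0, a pure group
print(gini(["cat", "cat", "dog", "dog"]))  # 0.5, a 50-50 binary group
```

Anything between these two extremes, such as a group that is three quarters one class, gives a score strictly between 0 and 0.5 in the binary case.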

The splitting process is terminated when all the data points have been turned into leaves and classified. However, you may want to stop the growth of the tree early. Large complex trees are prone to overfitting, but several different methods can be used to combat this. One method of reducing overfitting is to specify a minimum number of data points that will be used to create a leaf. Another method of controlling for overfitting is restricting the tree to a certain maximum depth, which controls how long a path can stretch from the root to a leaf.
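The sketch below shows how the two stopping rules could be wired into a recursive tree builder. It is an illustration under assumptions, not a reference implementation: the names `build`, `depth_of`, `max_depth`, and `min_samples_leaf` are hypothetical, features are assumed numeric, and candidate splits are scored with a size-weighted Gini score.

```python
from collections import Counter


def gini(labels):
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def majority(labels):
    return Counter(labels).most_common(1)[0][0]


def build(rows, labels, max_depth=3, min_samples_leaf=1, depth=0):
    """Grow a tree recursively. Leaves are plain labels, internal nodes are dicts.

    Assumes min_samples_leaf >= 1.
    """
    # Stop when the group is pure or the depth limit has been reached.
    if len(set(labels)) == 1 or depth >= max_depth:
        return majority(labels)
    best = None
    for f in range(len(rows[0])):
        for t in sorted({r[f] for r in rows}):
            left = [i for i, r in enumerate(rows) if r[f] < t]
            right = [i for i, r in enumerate(rows) if r[f] >= t]
            # Reject splits that would create a leaf smaller than allowed.
            if len(left) < min_samples_leaf or len(right) < min_samples_leaf:
                continue
            cost = (len(left) * gini([labels[i] for i in left])
                    + len(right) * gini([labels[i] for i in right])) / len(rows)
            if best is None or cost < best[0]:
                best = (cost, f, t, left, right)
    if best is None:  # no admissible split left: turn this group into a leaf
        return majority(labels)
    _, f, t, left, right = best
    return {
        "feature": f,
        "threshold": t,
        "left": build([rows[i] for i in left], [labels[i] for i in left],
                      max_depth, min_samples_leaf, depth + 1),
        "right": build([rows[i] for i in right], [labels[i] for i in right],
                       max_depth, min_samples_leaf, depth + 1),
    }


def depth_of(tree):
    """Length of the longest path from this node down to a leaf."""
    if not isinstance(tree, dict):
        return 0
    return 1 + max(depth_of(tree["left"]), depth_of(tree["right"]))


# Made-up, deliberately noisy data: the labels alternate along one feature.
rows = [[1], [2], [3], [4], [5], [6]]
labels = ["a", "b", "a", "b", "a", "b"]

print(depth_of(build(rows, labels, max_depth=10)))
print(depth_of(build(rows, labels, max_depth=1)))
print(depth_of(build(rows, labels, max_depth=10, min_samples_leaf=3)))
```

On this data the unrestricted tree keeps splitting until every leaf is pure, while either limit cuts growth off after a single split.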

Another process involved in the creation of decision trees is pruning. Pruning can help increase the performance of a decision tree by stripping out branches containing features that have little predictive power/little importance for the model. In this way, the complexity of the tree is reduced, it becomes less likely to overfit, and the predictive utility of the model is increased.

When conducting pruning, the process can start at either the top of the tree or the bottom of the tree. However, the easiest method of pruning is to start at the leaves and attempt to replace each node with a leaf labeled with the most common class at that node. If the accuracy of the model doesn’t deteriorate when this is done, then the change is preserved. There are other techniques used to carry out pruning, but the method described above, reduced error pruning, is probably the most common method of decision tree pruning.
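Reduced error pruning can be sketched in a few lines. The example below rests on several assumptions that are not in the text above: the tree is a nested dict, each internal node stores the most common training class under a hypothetical `"majority"` key, and accuracy is measured on a held-out set of rows that was not used to grow the tree.

```python
def predict(tree, row):
    """Follow the tests down to a leaf. Leaves are plain labels."""
    while isinstance(tree, dict):
        tree = tree["left"] if row[tree["feature"]] < tree["threshold"] else tree["right"]
    return tree


def errors(tree, rows, labels):
    return sum(predict(tree, r) != y for r, y in zip(rows, labels))


def prune(tree, rows, labels):
    """Bottom-up reduced error pruning using held-out rows and labels."""
    if not isinstance(tree, dict):
        return tree
    # Send each held-out row down the branch it would follow.
    go_left = [r[tree["feature"]] < tree["threshold"] for r in rows]
    l_rows = [r for r, g in zip(rows, go_left) if g]
    l_labels = [y for y, g in zip(labels, go_left) if g]
    r_rows = [r for r, g in zip(rows, go_left) if not g]
    r_labels = [y for y, g in zip(labels, go_left) if not g]
    # Prune the children first, so the process starts at the leaves.
    tree = dict(tree,
                left=prune(tree["left"], l_rows, l_labels),
                right=prune(tree["right"], r_rows, r_labels))
    # Keep the replacement leaf if accuracy does not get worse.
    # Nodes that no held-out row reaches are also collapsed by this rule.
    leaf = tree["majority"]
    if errors(leaf, rows, labels) <= errors(tree, rows, labels):
        return leaf
    return tree


# Hypothetical hand-built tree whose deepest split only fits noise.
tree = {
    "feature": 0, "threshold": 5, "majority": "a",
    "left": "a",
    "right": {"feature": 0, "threshold": 8, "majority": "b",
              "left": "b", "right": "a"},
}
val_rows = [[1], [3], [6], [7], [9]]
val_labels = ["a", "a", "b", "b", "b"]

pruned = prune(tree, val_rows, val_labels)
print(pruned["right"])  # b
```

Here the inner node is collapsed into a single leaf because doing so removes an error on the held-out data, while the root split is kept because replacing it would make the predictions worse.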

Considerations For Using Decision Trees

Decision trees are often useful when classification needs to be carried out but computation time is a major constraint. Decision trees can make it clear which features in the chosen datasets wield the most predictive power. Furthermore, unlike many machine learning algorithms where the rules used to classify the data may be hard to interpret, decision trees can render interpretable rules. Decision trees are also able to make use of both categorical and continuous variables which means that less preprocessing is needed, compared to algorithms that can only handle one of these variable types.

Decision trees tend not to perform very well when used to predict the values of continuous attributes, since every datapoint that reaches a given leaf receives the same predicted value. Another limitation of decision trees is that, when doing classification, if there are few training examples but many classes the decision tree tends to be inaccurate.

Blogger and programmer with specialties in Machine Learning and Deep Learning topics. Daniel hopes to help others use the power of AI for social good.