Decision Trees
Classifying New Data

We can finally use our tree as a classifier! Given a new data point, we start at the top of the tree and follow the path of the tree until we hit a leaf. Once we get to a leaf, we’ll use the classes of the points from the training set to make a classification.

We’ve slightly changed the way our build_tree() function works. Instead of returning a list of branches or a Counter object, the build_tree() function now returns a Leaf object or an Internal_Node object. We’ll explain how to use these objects in the instructions!

Let’s write a function that will use our tree to classify new points!

Instructions

1.

We’ve created a tree named tree using a lot of car data. Call the print_tree() function with tree as its argument to see it.

Notice that the tree now knows which feature was used to split the data. This new information is contained in the Leaf and Internal_Node classes. This will come in handy when we write our classify function!

Comment out printing the tree once you get a sense of how large it is!

2.

Let’s start writing the classify() function. classify() should take a datapoint and a tree as parameters.

The first thing classify should do is check to see if we’re at a leaf.

Check to see if tree is a Leaf by using the isinstance() function.

For example, isinstance(a, list) will be True if a is a list. You should check if tree is a Leaf.
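As a quick illustration of the isinstance() check, here is a minimal sketch. The Leaf class below is a hypothetical stand-in written only for this example; the lesson's own Leaf class is defined elsewhere and may look different.

```python
from collections import Counter

# Hypothetical stand-in for the lesson's Leaf class (an assumption,
# not the real definition).
class Leaf:
    def __init__(self, labels):
        self.labels = labels  # label counts, e.g. a Counter

node = Leaf(Counter({"acc": 3, "unacc": 1}))

print(isinstance([1, 2, 3], list))  # True: the object is a list
print(isinstance(node, Leaf))       # True: node is a Leaf
print(isinstance(node, list))       # False: a Leaf is not a list
```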

If we’ve found a Leaf, that means we want to return the label with the highest count. The label counts are stored in tree.labels.

You could find the label with the largest count by using a for loop, or by using this rather complicated line of code, which relies on Python’s operator module being imported:

return max(tree.labels.items(), key=operator.itemgetter(1))[0]
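To compare the two approaches, the sketch below finds the most common label both ways. The labels dictionary is made-up example data standing in for tree.labels, and the variable names best_label, best_count, and one_liner are chosen only for this illustration.

```python
import operator
from collections import Counter

# Made-up label counts, standing in for what a leaf stores in tree.labels.
labels = Counter({"unacc": 5, "acc": 2, "good": 1})

# Option 1: an explicit loop that remembers the best label seen so far.
best_label = None
best_count = -1
for label, count in labels.items():
    if count > best_count:
        best_label = label
        best_count = count

# Option 2: the one-liner. itemgetter(1) picks the count out of each
# (label, count) pair, and [0] takes the label from the winning pair.
one_liner = max(labels.items(), key=operator.itemgetter(1))[0]

print(best_label)  # unacc
print(one_liner)   # unacc
```

Both versions return the same label here; the loop is longer but may be easier to read while you are getting used to the idea.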

3.

If we’re not at a leaf, we want to find the branch that corresponds to our data point. For example, if we’re splitting on index 0 and our data point is ['med', 'low', '4', '2', 'big', 'low'], we want to find the branch that contains all of the points with med at index 0.

To start, let’s find datapoint’s value for the feature we’re looking for. If datapoint were the example above, and the feature we’re interested in is 0, this would be med.

Outside the if statement, create a variable named value and set it equal to datapoint[tree.feature]. tree.feature contains the index of the feature that we’re splitting on, so datapoint[tree.feature] is the value at that index.

To help us check your code, return value.
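The lookup in this step is ordinary list indexing. A small sketch, where the plain variable feature is a hypothetical stand-in for tree.feature:

```python
# The example data point from the text.
datapoint = ['med', 'low', '4', '2', 'big', 'low']

# Assumption for illustration: the current node splits on index 0.
# In the exercise this number comes from tree.feature.
feature = 0

value = datapoint[feature]
print(value)  # med
```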

4.

Start by deleting return value.

Let’s now loop through all of the branches in the tree to find the one that has all the data points with value at the correct index.

Your loop should look like this:

for branch in tree.branches:

Next, inside the loop, check to see if branch.value is equal to value. If it is, we’ve found the branch that we’re looking for! We want to now recursively call classify() on that branch:

return classify(datapoint, branch)

As long as the data point’s value for this feature also appeared in the training data at this node, exactly one branch will match, so this return statement will run once per call.
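Putting steps 2 through 4 together, one possible shape for classify() is sketched below. The Leaf and Internal_Node classes here are simplified, hypothetical stand-ins (their constructors are assumptions, not the lesson's real definitions), and the tiny hand-built tree is made-up data rather than the car tree from the exercise.

```python
import operator
from collections import Counter

# Hypothetical stand-ins for the lesson's classes. Only the attributes
# named in the instructions (labels, feature, branches, value) are used.
class Leaf:
    def __init__(self, labels, value):
        self.labels = labels  # Counter of training labels at this leaf
        self.value = value    # feature value that leads to this node

class Internal_Node:
    def __init__(self, feature, branches, value):
        self.feature = feature    # index of the feature split on
        self.branches = branches  # child nodes, one per feature value
        self.value = value        # feature value that leads to this node

def classify(datapoint, tree):
    # At a leaf: return the label with the highest count.
    if isinstance(tree, Leaf):
        return max(tree.labels.items(), key=operator.itemgetter(1))[0]
    # Otherwise: look up the data point's value for the split feature...
    value = datapoint[tree.feature]
    # ...and recurse into the branch that matches it.
    for branch in tree.branches:
        if branch.value == value:
            return classify(datapoint, branch)

# A tiny made-up tree that splits once, on index 0.
tree = Internal_Node(
    feature=0,
    branches=[
        Leaf(Counter({"unacc": 4, "acc": 1}), value="low"),
        Leaf(Counter({"acc": 3, "unacc": 1}), value="med"),
    ],
    value=None,
)

test_point = ['med', 'low', '4', '2', 'big', 'low']
print(classify(test_point, tree))  # acc
```

In this sketch, if no branch matches the data point's value, the loop finishes without returning and the function returns None.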

5.

Finally, outside of your function, call classify() using test_point and tree as arguments. Print the result. You should see a classification for this new point.
