Decision Trees for Regression?

Jason Drummond
6 min read · Nov 23, 2020


After working through the Flatiron Data Science curriculum, I realized that some of my earlier projects would have been better “fitted” by techniques we learned later in the program. In particular, the dataset I used for my Mod 2 project would work better with some of the models we learned in Mod 3. For this blog post I set out to completely redo my Mod 2 project and reframe it as a classification problem, or so I thought.

To remind everyone: I gathered data from the NYC Open Data website and set out to predict a person’s salary based on the features available to me. As I worked with this dataset originally, I wound up having to use far too many categorical features, which made me believe I had chosen a horrible dataset for fitting a regression model. Once I learned about the models introduced in Mod 3 (logistic regression, decision trees, random forests, and a few others), I figured it would be best to reframe this project as a classification problem.

In my research online I came across a StatQuest video, linked below, that made me rethink my choice. I learned that you can build a regression model with decision trees. I will briefly go over the inner workings of this model, but I highly recommend watching the video, as it goes into great detail showing how everything works.

How Do Regression Trees Work?

Simple Linear Regression Model

From the graph above we can see that this dataset would be a perfect candidate for linear regression. It’s easy to see that as the drug dosage increases, the effectiveness of the drug also increases. But what if the relationship were not so easy to spot? As I, and probably many of you, have seen while working on projects, we are not always able to find these “perfect” relationships. In some cases it may be better to choose a different model, or perhaps to reframe the problem at hand. For instance, what if we had a dataset that looked like the following graph?

Can we use Regression techniques on this type of model?

The graph above looks pretty odd, doesn’t it? If we tried to fit a regression line as we did in the earlier example, our predictions would be way off. If only there were a way to harness the power of the tree-based models we just learned without being limited to predicting classes. Allow me to introduce Regression Trees. Regression Trees are awesome, at least in my opinion, and this type of model would have been perfectly suited to my dataset.

The way regression trees work is actually fairly simple. Imagine a vertical line between the first two data points in the graph above. We average those two x-values and use the result as a candidate threshold: every observation below the threshold is averaged into one leaf, and every observation above it is averaged into another leaf. For this split we calculate the residual sum of squares between each observation and the mean of its leaf, then move on to the next split. Next we imagine a vertical line between the second and third data points, average them to get the next candidate threshold, again average the observations on each side to form the leaves, and again record the residual sum of squares. We repeat this for every possible split in the dataset, and the candidate with the lowest residual sum of squares becomes our root node! We can then do the same thing recursively inside the left and right nodes until we reach the leaves.
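To make that concrete, here is a minimal sketch of the split search on a made-up dosage/effectiveness dataset (the numbers are invented for illustration, not taken from the video):

    import numpy as np

    # Toy data loosely shaped like the graph above: effectiveness rises,
    # plateaus, then falls again as dosage increases.
    dosage = np.array([3, 5, 8, 10, 13, 16, 19, 23, 26, 29], dtype=float)
    effectiveness = np.array([0, 0, 5, 28, 100, 100, 52, 18, 0, 0], dtype=float)

    def best_split(x, y):
        """Try a threshold halfway between each pair of adjacent x values
        and keep the one with the lowest residual sum of squares."""
        order = np.argsort(x)
        x, y = x[order], y[order]
        best_threshold, best_ssr = None, np.inf
        for i in range(len(x) - 1):
            threshold = (x[i] + x[i + 1]) / 2      # the "temporary node"
            left, right = y[x <= threshold], y[x > threshold]
            # Each side is predicted by its mean (the leaf value).
            ssr = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
            if ssr < best_ssr:
                best_threshold, best_ssr = threshold, ssr
        return best_threshold, best_ssr

    threshold, ssr = best_split(dosage, effectiveness)
    print(f"Root node: dosage < {threshold} (SSR = {ssr:.1f})")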

Wait, wait, wait! Won’t this basically create a leaf for every observation in our dataset, or, in more professional terms, won’t this method overfit our model? That is completely true, but we have a way to stop overfitting to the training data: we only split a node if it contains more than some minimum number of data points. A commonly cited value is twenty (scikit-learn’s default for min_samples_split is actually 2). We can tune this value to better suit our dataset. In the example above, twenty would be way too high, since we only have a small number of data points. If instead we set the minimum split value to, say, 7, we would wind up with 4 leaves, and the value in each leaf would be the average of the data points that fall into it.
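To see the stopping rule in action, here is a small sketch reusing the toy arrays from the previous snippet; min_samples_split is the real scikit-learn parameter name, though the exact leaf counts depend on the data:

    from sklearn.tree import DecisionTreeRegressor

    X = dosage.reshape(-1, 1)  # sklearn expects a 2-D feature matrix

    deep_tree = DecisionTreeRegressor(min_samples_split=2).fit(X, effectiveness)
    pruned_tree = DecisionTreeRegressor(min_samples_split=7).fit(X, effectiveness)

    print(deep_tree.get_n_leaves())    # grows many small leaves and overfits
    print(pruned_tree.get_n_leaves())  # far fewer leaves, each averaging several points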

Regression Trees in Action

First we will need to import the necessary model from sklearn:
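    from sklearn.tree import DecisionTreeRegressor  # the regression counterpart of DecisionTreeClassifier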

We will then follow the same workflow as for any other model. Before proceeding, make sure you have performed the necessary preprocessing steps, scaled your data if needed, and split your data into train and test sets. At this point all we have to do is fit the model to our training set and then predict on our test set:
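A minimal sketch, assuming the preprocessing above has already produced X_train, X_test, y_train, and y_test (placeholder names, not from the original project):

    reg = DecisionTreeRegressor(random_state=42)
    reg.fit(X_train, y_train)       # learn the splits from the training data
    y_pred = reg.predict(X_test)    # continuous predictions for unseen rows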

Wasn’t that easy? At this point we can tune our hyperparameters if we want to; remember that we cannot tune all of the same hyperparameters as we did with DecisionTreeClassifier. One we may definitely want to tune, though, is min_samples_split. This is the parameter that controls when to stop splitting the data, as I described in the example above.
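If we wanted to search over that parameter, a small grid search is one way to do it (the candidate values here are arbitrary and should be adjusted to the size of your dataset):

    from sklearn.model_selection import GridSearchCV

    params = {"min_samples_split": [2, 5, 10, 20, 50]}
    search = GridSearchCV(DecisionTreeRegressor(random_state=42), params, cv=5)
    search.fit(X_train, y_train)
    print(search.best_params_)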

Results on Mod 2 Dataset

To see whether these models were any better, I fit a Decision Tree Regressor and a Random Forest Regressor to my Mod 2 dataset and compared them to the best model I had found originally. The results are as follows.

These results are quite impressive: the random forest model achieved roughly a 30 percent improvement in RMSE compared to the linear regression model that had come out on top in my Mod 2 project. The R-squared value saw just as impressive an increase, going from about 0.58 to an astounding 0.88.
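For reference, here is a sketch of how those two metrics can be computed with scikit-learn (the model settings and variable names are illustrative, not the exact code from my project):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import mean_squared_error, r2_score

    rf = RandomForestRegressor(random_state=42)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)

    rmse = np.sqrt(mean_squared_error(y_test, y_pred))  # lower is better
    r2 = r2_score(y_test, y_pred)                       # closer to 1 is better
    print(f"RMSE: {rmse:.2f}  R^2: {r2:.3f}")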

Conclusion

Regression Trees are quite impressive and seem great for data that mixes a lot of categorical features with continuous ones. We can harness the power of the decision tree model to predict a continuous target variable. I highly recommend checking out DecisionTreeRegressor, as well as the StatQuest video on this topic linked below.

Resources

StatQuest: Regression Trees, Clearly Explained (YouTube)
