GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent

Authors

  • Sascha Marton, University of Mannheim
  • Stefan Lüdtke, University of Rostock
  • Christian Bartelt, University of Mannheim
  • Heiner Stuckenschmidt, University of Mannheim

DOI:

https://doi.org/10.1609/aaai.v38i13.29345

Keywords:

ML: Neuro-Symbolic Learning, ML: Classification and Regression

Abstract

Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results for multi-class tasks. The implementation is available at https://github.com/s-marton/GradTree
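To make the straight-through idea mentioned in the abstract concrete, the sketch below shows a single hard, axis-aligned split node whose forward pass is discrete while gradients flow through a soft relaxation. This is a minimal illustration under assumed names (straight_through_split, weights, threshold), not the authors' implementation; consult the GradTree repository for the actual dense tree representation and training procedure.

import torch

def straight_through_split(x, weights, threshold):
    """Hard, axis-aligned split with a straight-through gradient (illustrative only).

    x:         (batch, n_features) input
    weights:   (n_features,) feature-selection logits
    threshold: scalar split value for the selected feature
    """
    # Soft feature selection (differentiable surrogate)
    soft_select = torch.softmax(weights, dim=0)
    # Hard one-hot selection; straight-through: hard value in the forward pass,
    # gradient taken from the soft surrogate in the backward pass
    hard_select = torch.zeros_like(soft_select)
    hard_select[torch.argmax(soft_select)] = 1.0
    select = hard_select + soft_select - soft_select.detach()

    feature_value = x @ select              # picks one feature per sample
    logit = feature_value - threshold

    soft_decision = torch.sigmoid(logit)    # differentiable surrogate
    hard_decision = (logit > 0).float()     # hard, axis-aligned routing
    return hard_decision + soft_decision - soft_decision.detach()

# Usage sketch: gradients reach the split parameters despite the hard forward pass.
x = torch.randn(8, 4)
w = torch.randn(4, requires_grad=True)
t = torch.tensor(0.0, requires_grad=True)
straight_through_split(x, w, t).sum().backward()
assert w.grad is not None and t.grad is not None

In GradTree, this straight-through treatment is applied jointly to all nodes of a dense tree representation, so gradient descent optimizes split features, split values, and leaf predictions simultaneously rather than greedily node by node.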

Published

2024-03-24

How to Cite

Marton, S., Lüdtke, S., Bartelt, C., & Stuckenschmidt, H. (2024). GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent. Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14323-14331. https://doi.org/10.1609/aaai.v38i13.29345

Issue

Vol. 38 No. 13 (2024)

Section

AAAI Technical Track on Machine Learning IV