Cómo ajustar árboles de clasificación y regresión en R

Actualizado por ultima vez el 7 de mayo de 2021, por .

Cuando la relación entre un conjunto de variables predictoras y una variable de respuesta es lineal, métodos como la regresión lineal múltiple pueden producir modelos predictivos precisos.

Sin embargo, cuando la relación entre un conjunto de predictores y una respuesta es más compleja, los métodos no lineales a menudo pueden producir modelos más precisos.

Uno de estos métodos son los árboles de clasificación y regresión (CART), que utilizan un conjunto de variables predictoras para construir árboles de decisión que predicen el valor de una variable de respuesta.

Si la variable de respuesta es continua, podemos construir árboles de regresión y si la variable de respuesta es categórica, podemos construir árboles de clasificación.

Este tutorial explica cómo construir árboles de regresión y clasificación en R.

Ejemplo 1: creación de un árbol de regresión en R

Para este ejemplo, usaremos el conjunto de datos de Hitters del paquete ISLR , que contiene información diversa sobre 263 jugadores de béisbol profesionales.

Usaremos este conjunto de datos para construir un árbol de regresión que use las variables predictoras de jonrones y años jugados para predecir el salario de un jugador dado.

Utilice los siguientes pasos para construir este árbol de regresión.

Paso 1: Cargue los paquetes necesarios.

Primero, cargaremos los paquetes necesarios para este ejemplo:

biblioteca (ISLR) #contiene la 
biblioteca de conjuntos de datos de Hitters ( rpart
 ) #para ajustar la biblioteca de árboles de decisión (rpart.plot) #para trazar árboles de decisión

Paso 2: Construya el árbol de regresión inicial.

Primero, construiremos un gran árbol de regresión inicial. Podemos asegurarnos de que el árbol sea grande usando un valor pequeño para cp , que significa «parámetro de complejidad».

Esto significa que realizaremos nuevas divisiones en el árbol de regresión siempre que el R-cuadrado general del modelo aumente al menos en el valor especificado por cp.

Luego usaremos la función printcp () para imprimir los resultados del modelo:

#build the initial tree
 tree <- rpart (Salario ~ Años + HmRun, data = Hitters, control = rpart. control (cp = .0001 ))

#Ver resultados
printcp (árbol)

Variables realmente utilizadas en la construcción de árboles:
[1] HmRun años

Error de nodo raíz: 53319113/263 = 202734

n = 263 (59 observaciones eliminadas debido a falta)

           CP nsplit rel error xerror xstd
1 0,24674996 0 1,00000 1,00756 0,13890
2 0,10806932 1 0,75325 0,76438 0,12828
3 0,01865610 2 0,64518 0,70295 0,12769
4 0,01761100 3 0,62652 0,70339 0,12337
5 0,01747617 4 0,60891 0,70339 0,12337
6 0,01038188 5 0,59144 0,66629 0,11817
7 0,01038065 6 0,58106 0,65697 0,11687
8 0,00731045 8 0,56029 0,67177 0,11913
9 0,00714883 9 0,55298 0,67881 0,11960
10 0,00708618 10 0,54583 0,68034 0,11988
11 0,00516285 12 0,53166 0,68427 0,11997
12 0,00445345 13 0,52650 0,68994 0,11996
13 0,00406069 14 0,52205 0,68988 0,11940
14 0,00264728 15 0,51799 0,68874 0,11916
15 0,00196586 16 0,51534 0,68638 0,12043
16 0.00016686 17 0.51337 0.67577 0.11635
17 0.00010000 18 0.51321 0.67576 0.11615
n = 263 (59 observaciones eliminadas debido a falta)

Paso 3: poda el árbol.

A continuación, podaremos el árbol de regresión para encontrar el valor óptimo a utilizar para cp (el parámetro de complejidad) que conduce al error de prueba más bajo.

Tenga en cuenta que el valor óptimo de cp es el que conduce al xerror más bajo en la salida anterior, que representa el error en las observaciones de los datos de validación cruzada.

#identifique el mejor valor cp para usar el
 mejor <- árbol $ cptable [which. min (árbol $ cptable [, " xerror "]), " CP "]

#producir un árbol podado basado en el mejor valor de cp
 pruned_tree <- podar (árbol, cp = mejor)

#plote el árbol podado
 prp (pruned_tree,
    faclen = 0 , #use nombres completos para etiquetas de factor 
    extra = 1 , #muestra el número de obs. para cada nodo terminal 
    roundint = F , # no redondear a números enteros en los 
    dígitos de salida = 5 ) # mostrar 5 lugares decimales en la salida

Árbol de regresión en R

Podemos ver que el árbol podado final tiene seis nodos terminales. Cada nodo terminal muestra el salario previsto de los jugadores en ese nodo junto con el número de observaciones del conjunto de datos original que pertenecen a esa nota.

Por ejemplo, podemos ver que en el conjunto de datos original había 90 jugadores con menos de 4.5 años de experiencia y su salario promedio era de $ 225.83k.

Interpretar un árbol de regresión en R

Paso 4: usa el árbol para hacer predicciones.

Podemos usar el árbol podado final para predecir el salario de un jugador determinado en función de sus años de experiencia y jonrones promedio.

Por ejemplo, un jugador que tiene 7 años de experiencia y 4 jonrones promedio tiene un salario previsto de 502,81 mil dólares .

Ejemplo de árbol de regresión en R

Podemos usar la función de predicción () en R para confirmar esto:

#define new player
 new <- data.frame (Years = 7, HmRun = 4)

#utilice el árbol podado para predecir el salario de este jugador
 predecir (pruned_tree, newdata = new)

502.8079 

Ejemplo 2: Creación de un árbol de clasificación en R

Para este ejemplo, vamos a utilizar el ptitanic conjunto de datos de la rpart.plot paquete, que contiene diversa información sobre los pasajeros a bordo del Titanic.

Usaremos este conjunto de datos para construir un árbol de clasificación que use las variables predictoras class , sex y age para predecir si un pasajero dado sobrevivió o no.

Utilice los siguientes pasos para crear este árbol de clasificación.

Paso 1: Cargue los paquetes necesarios.

Primero, cargaremos los paquetes necesarios para este ejemplo:

biblioteca ( rpart
 ) # para ajustar árboles de decisión biblioteca (rpart.plot) # para trazar árboles de decisión

Paso 2: Construya el árbol de clasificación inicial.

Primero, construiremos un gran árbol de clasificación inicial. Podemos asegurarnos de que el árbol sea grande usando un valor pequeño para cp , que significa «parámetro de complejidad».

Esto significa que realizaremos nuevas divisiones en el árbol de clasificación siempre que el ajuste general del modelo aumente al menos en el valor especificado por cp.

Luego usaremos la función printcp () para imprimir los resultados del modelo:

#build the initial tree
 tree <- rpart (sobrevivido ~ pclass + sexo + edad, datos = ptitanic, control = rpart. control (cp = .0001 ))

#Ver resultados
printcp (árbol)

Variables realmente utilizadas en la construcción de árboles:
[1] edad pclass sexo   

Error de nodo raíz: 500/1309 = 0.38197

n = 1309 

      CP nsplit rel error xerror xstd
1 0,4240 0 1,000 1,000 0,035158
2 0,0140 1 0,576 0,576 0,029976
3 0,0095 3 0,548 0,578 0,030013
4 0,0070 7 0,510 0,552 0,029517
5 0,0050 9 0,496 0,528 0,029035
6 0,0025 11 0,486 0,532 0,029117
7 0,0020 19 0,464 0,536 0,029198
8 0,0001 22 0,458 0,528 0,029035

Paso 3: poda el árbol.

A continuación, podaremos el árbol de regresión para encontrar el valor óptimo a utilizar para cp (el parámetro de complejidad) que conduce al error de prueba más bajo.

Tenga en cuenta que el valor óptimo de cp es el que conduce al xerror más bajo en la salida anterior, que representa el error en las observaciones de los datos de validación cruzada.

#identifique el mejor valor cp para usar el
 mejor <- árbol $ cptable [which. min (árbol $ cptable [, " xerror "]), " CP "]

#producir un árbol podado basado en el mejor valor de cp
 pruned_tree <- podar (árbol, cp = mejor)

#plote el árbol podado
 prp (pruned_tree,
    faclen = 0 , #use nombres completos para las etiquetas de los factores
    extra = 1 , #muestra el número de obs. para cada nodo terminal
    roundint = F , # no redondear a números enteros en la salida
    dígitos = 5 ) # mostrar 5 posiciones decimales en la salida

Árbol de clasificación en R

Podemos ver que el árbol podado final tiene 10 nodos terminales. Cada nodo terminal muestra el número de pasajeros que murieron junto con el número que sobrevivió.

Por ejemplo, en el nodo del extremo izquierdo vemos que 664 pasajeros murieron y 136 sobrevivieron.

Interpretación del árbol de clasificación en R

Paso 4: usa el árbol para hacer predicciones.

Podemos usar el árbol podado final para predecir la probabilidad de que un pasajero determinado sobreviva según su clase, edad y sexo.

Por ejemplo, un pasajero masculino que está en 1ra clase y tiene 8 años tiene una probabilidad de supervivencia del 29/11 = 37,9%.

Árbol de clasificación en R

Puede encontrar el código R completo utilizado en estos ejemplos aquí .

  • https://r-project.org
  • https://www.python.org/
  • https://www.stata.com/

Deja un comentario

Los investigadores a menudo están interesados ​​en responder preguntas sobre poblaciones como: ¿Cuál es la altura media de una determinada…
statologos comunidad-2

Compartimos información EXCLUSIVA y GRATUITA solo para suscriptores (cursos privados, programas, consejos y mucho más)

You have Successfully Subscribed!