Contenido de este artículo
- 0
- 0
- 0
- 0
El análisis discriminante lineal es un método que puede utilizar cuando tiene un conjunto de variables predictoras y le gustaría clasificar una variable de respuesta en dos o más clases.
Este tutorial proporciona un ejemplo paso a paso de cómo realizar un análisis discriminante lineal en Python.
Paso 1: cargue las bibliotecas necesarias
Primero, cargaremos las funciones y bibliotecas necesarias para este ejemplo:
de sklearn. model_selection importar train_test_split desde sklearn. model_selection importar RepeatedStratifiedKFold desde sklearn. model_selection importar cross_val_score de sklearn. discriminant_analysis importar LinearDiscriminantAnalysis de sklearn importar conjuntos de datos importar matplotlib. pyplot como plt importar pandas como pd importar numpy como np
Paso 2: cargue los datos
Para este ejemplo, usaremos el conjunto de datos de iris de la biblioteca sklearn. El siguiente código muestra cómo cargar este conjunto de datos y convertirlo en un DataFrame de pandas para que sea más fácil trabajar con él:
#load iris dataset iris = datasets. load_iris () #convertir conjunto de datos a pandas DataFrame df = pd.DataFrame (data = np.c_ [iris [' data '], iris [' target ']], columnas = iris [' feature_names '] + [' target ']) df [' especie '] = pd. Categórico . from_codes (iris.target, iris.target_names) df.columns = [' s_length ', ' s_width ', ' p_length ', ' p_width ', ' target ', ' especie '] #ver las primeras seis filas de DataFrame df. cabeza () s_length s_width p_length p_width especies de destino 0 5,1 3,5 1,4 0,2 0,0 setosa 1 4,9 3,0 1,4 0,2 0,0 setosa 2 4,7 3,2 1,3 0,2 0,0 setosa 3 4,6 3,1 1,5 0,2 0,0 setosa 4 5,0 3,6 1,4 0,2 0,0 setosa # encontrar cuántas observaciones totales hay en el conjunto de datos len ( índice df. ) 150
Podemos ver que el conjunto de datos contiene 150 observaciones en total.
Para este ejemplo, crearemos un modelo de análisis discriminante lineal para clasificar a qué especie pertenece una flor determinada.
Usaremos las siguientes variables predictoras en el modelo:
- Longitud del sépalo
- Ancho del sépalo
- Longitud del pétalo
- Ancho del pétalo
Y los usaremos para predecir la variable de respuesta Species , que toma las siguientes tres clases potenciales:
- setosa
- versicolor
- virginica
Paso 3: ajuste el modelo LDA
A continuación, ajustaremos el modelo LDA a nuestros datos usando la función LinearDiscriminantAnalsyis de sklearn:
#definir variables de predicción y respuesta X = df [[' s_length ', ' s_width ', ' p_length ', ' p_width ']] y = df [' especie '] #Ajuste el modelo del modelo LDA = LinearDiscriminantAnalysis () modelo. encajar (X, y)
Paso 4: use el modelo para hacer predicciones
Una vez que hemos ajustado el modelo usando nuestros datos, podemos evaluar qué tan bien se desempeñó el modelo usando la validación cruzada estratificada repetida de k-veces.
Para este ejemplo, usaremos 10 pliegues y 3 repeticiones:
#Define el método para evaluar el modelo cv = RepeatedStratifiedKFold (n_splits = 10 , n_repeats = 3 , random_state = 1 ) #evaluar modelo puntuaciones = cross_val_score (modelo, X, y, puntuación = ' precisión ', cv = cv, n_jobs = -1) imprimir (np. media (puntuaciones)) 0.9777777777777779
Podemos ver que el modelo obtuvo una precisión media del 97,78% .
También podemos usar el modelo para predecir a qué clase pertenece una nueva flor, en función de los valores de entrada:
#definir nueva observación nueva = [5, 3, 1, .4] # predice a qué clase pertenece la nueva observación del modelo. predecir ([nuevo]) matriz (['setosa'], dtype = '<U10')
Podemos ver que el modelo predice que esta nueva observación pertenece a la especie llamada setosa .
Paso 5: Visualice los resultados
Por último, podemos crear una gráfica LDA para ver los discriminantes lineales del modelo y visualizar qué tan bien separó las tres especies diferentes en nuestro conjunto de datos:
# definir datos para trazar X = iris.data y = iris.target modelo = LinearDiscriminantAnalysis () diagrama_datos = modelo. encajar (X, y). transformar (X) target_names = iris. target_names #create LDA plot plt. figura () colores = [' rojo ', ' verde ', ' azul '] lw = 2 para color, i, target_name en zip (colores, [0, 1, 2], target_names): plt. scatter (data_plot [y == i, 0], data_plot [y == i, 1], alpha = .8, color = color, label = target_name) #add leyenda para trazar plt. leyenda (loc = ' mejor ', sombra = Falso , puntos de dispersión = 1) #display LDA plot plt. mostrar ()
Puede encontrar el código Python completo utilizado en este tutorial aquí .
- https://r-project.org
- https://www.python.org/
- https://www.stata.com/
¿Te hemos ayudado?
Ayudanos ahora tú, dejanos un comentario de agradecimiento, nos ayuda a motivarnos y si te es viable puedes hacer una donación:La ayuda no cuesta nada
Por otro lado te rogamos que compartas nuestro sitio con tus amigos, compañeros de clase y colegas, la educación de calidad y gratuita debe ser difundida, recuerdalo: