Los modelos de aprendizaje automático pueden fallar cuando intentan hacer predicciones para personas que estaban subrepresentadas en los conjuntos de datos en los que fueron entrenados.
Por ejemplo, un modelo que predice la mejor opción de tratamiento para alguien con una enfermedad crónica puede entrenarse utilizando un conjunto de datos que contenga principalmente pacientes masculinos. Ese modelo podría hacer predicciones incorrectas para las pacientes femeninas cuando estén internadas en un hospital.
Para mejorar los resultados, los ingenieros pueden intentar equilibrar el conjunto de datos de entrenamiento eliminando puntos de datos hasta que todos los subgrupos estén representados por igual. Si bien el equilibrio del conjunto de datos es prometedor, a menudo requiere eliminar una gran cantidad de datos, lo que perjudica el rendimiento general del modelo.
Los investigadores del MIT desarrollaron una nueva técnica que identifica y elimina puntos específicos en un conjunto de datos de entrenamiento que más contribuyen a las fallas de un modelo en subgrupos minoritarios. Al eliminar muchos menos puntos de datos que otros enfoques, esta técnica mantiene la precisión general del modelo y al mismo tiempo mejora su rendimiento con respecto a los grupos subrepresentados.
Además, la técnica puede identificar fuentes ocultas de sesgo en un conjunto de datos de entrenamiento que carece de etiquetas. Los datos sin etiquetar son mucho más frecuentes que los datos etiquetados para muchas aplicaciones.
Este método también podría combinarse con otros enfoques para mejorar la equidad de los modelos de aprendizaje automático implementados en situaciones de alto riesgo. Por ejemplo, algún día podría ayudar a garantizar que los pacientes subrepresentados no sean diagnosticados erróneamente debido a un modelo de IA sesgado.
“Muchos otros algoritmos que intentan abordar este problema asumen que cada punto de datos es tan importante como cualquier otro. En este artículo, demostramos que esa suposición no es cierta. Hay puntos específicos en nuestro conjunto de datos que contribuyen a este sesgo, y podemos encontrar esos puntos de datos, eliminarlos y obtener un mejor rendimiento”, dice Kimia Hamidieh, estudiante de posgrado en ingeniería eléctrica e informática (EECS) en el MIT y compañía. -Autor principal de un artículo sobre esta técnica.
Escribió el artículo con los coautores principales Saachi Jain PhD ’24 y su compañero estudiante graduado de EECS Kristian Georgiev; Andrew Ilyas MEng ’18, PhD ’23, miembro Stein de la Universidad de Stanford; y los autores principales Marzyeh Ghassemi, profesor asociado en EECS y miembro del Instituto de Ciencias de Ingeniería Médica y del Laboratorio de Sistemas de Información y Decisión, y Aleksander Madry, profesor de Sistemas de Diseño de Cadencia en el MIT. La investigación se presentará en la Conferencia sobre Sistemas de Procesamiento de Información Neural.
Eliminando malos ejemplos
A menudo, los modelos de aprendizaje automático se entrenan utilizando enormes conjuntos de datos recopilados de muchas fuentes en Internet. Estos conjuntos de datos son demasiado grandes para ser cuidadosamente seleccionados a mano, por lo que pueden contener malos ejemplos que perjudican el rendimiento del modelo.
Los científicos también saben que algunos puntos de datos afectan más que otros el rendimiento de un modelo en determinadas tareas posteriores.
Los investigadores del MIT combinaron estas dos ideas en un enfoque que identifica y elimina estos puntos de datos problemáticos. Buscan resolver un problema conocido como error del peor grupo, que ocurre cuando un modelo tiene un rendimiento inferior en subgrupos minoritarios en un conjunto de datos de entrenamiento.
La nueva técnica de los investigadores está impulsada por trabajos anteriores en los que introdujeron un método, llamado TRAK, que identifica los ejemplos de entrenamiento más importantes para un resultado de modelo específico.
Para esta nueva técnica, toman predicciones incorrectas que hizo el modelo sobre subgrupos minoritarios y usan TRAK para identificar qué ejemplos de entrenamiento contribuyeron más a esa predicción incorrecta.
«Al agregar esta información de las predicciones de pruebas incorrectas de la manera correcta, podemos encontrar las partes específicas del entrenamiento que están reduciendo la precisión general del peor grupo», explica Ilyas.
Luego eliminan esas muestras específicas y vuelven a entrenar el modelo con los datos restantes.
Dado que tener más datos generalmente produce un mejor rendimiento general, eliminar solo las muestras que generan peores fallas en los grupos mantiene la precisión general del modelo y al mismo tiempo mejora su rendimiento en los subgrupos minoritarios.
Un enfoque más accesible
En tres conjuntos de datos de aprendizaje automático, su método superó a múltiples técnicas. En un caso, aumentó la precisión del peor grupo y al mismo tiempo eliminó alrededor de 20.000 muestras de entrenamiento menos que un método de equilibrio de datos convencional. Su técnica también logró una mayor precisión que los métodos que requieren realizar cambios en el funcionamiento interno de un modelo.
Debido a que el método MIT implica cambiar un conjunto de datos, sería más fácil de usar para un profesional y se puede aplicar a muchos tipos de modelos.
También se puede utilizar cuando se desconoce el sesgo porque los subgrupos de un conjunto de datos de entrenamiento no están etiquetados. Al identificar los puntos de datos que contribuyen más a una característica que el modelo está aprendiendo, pueden comprender las variables que utiliza para hacer una predicción.
“Esta es una herramienta que cualquiera puede utilizar cuando entrena un modelo de aprendizaje automático. Pueden observar esos puntos de datos y ver si están alineados con la capacidad que están tratando de enseñar al modelo”, dice Hamidieh.
Usar la técnica para detectar sesgos de subgrupos desconocidos requeriría intuición sobre qué grupos buscar, por lo que los investigadores esperan validarlo y explorarlo más completamente a través de futuros estudios en humanos.
También quieren mejorar el rendimiento y la confiabilidad de su técnica y garantizar que el método sea accesible y fácil de usar para los profesionales que algún día podrían implementarlo en entornos del mundo real.
«Cuando tienes herramientas que te permiten observar críticamente los datos y descubrir qué puntos de datos conducirán a sesgos u otros comportamientos indeseables, te da un primer paso hacia la construcción de modelos que serán más justos y más confiables». dice Ilyas.
Este trabajo está financiado, en parte, por la Fundación Nacional de Ciencias y la Agencia de Proyectos de Investigación Avanzada de Defensa de EE. UU.