Mentoring in Machine Learning
Ever wonder what it’s like to be a research intern with Stanford’s Center for Artificial Intelligence in Medicine & Imaging (AIMI)?
We (Rebekah Westerlind, a Cornell University junior double majoring in Operations Research & Information Engineering, and Kathryn Garcia, a Stanford University sophomore majoring in Computer Science) interned during Summer 2020 in the Lungren lab and AIMI Center. We built a convolutional neural network to better classify chest X-rays under the guidance of two PhD-students-turned-friends, Anuj Pareek and Mars Huang.
The Stanford AIMI team had already developed a model with good performance detecting 14 different abnormalities in chest X-rays (you can read about it in this CheXnet paper), but its accuracy was lower than desired for one particular condition: pneumothorax. Pneumothorax, or collapsed lung, is a visually subtle but life-threatening condition. Since the consequences of missing a pneumothorax are severe, we set to work building a model to improve classification accuracy on chest X-rays exhibiting this condition. We hoped to integrate our model into the original to improve its overall accuracy and its performance in prioritizing radiologist reading lists in hospital settings.
Radiologists are tasked with reading hundreds of images a day. Because of the complexity of medical images, there are often long delays before patients receive a diagnosis. Images can be flagged as urgent for patients who are clearly in critical condition, but most often the priority list is organized according to the time the image was taken. This leaves some at-risk patients waiting in line while crucial moments pass. The CheXpert model was created to ameliorate this problem. First, the model classifies a chest X-ray as healthy or unhealthy with varying degrees of severity. Then, the images are prioritized based on a weighted combination of the severity of the model’s prediction and how long the patient has been waiting for a diagnosis. In this way, critical cases can be seen by radiologists more quickly, and radiologists can use the model as a tool to improve their efficiency. With this vision of improving medical outcomes in mind, we were eager to get to work.
Over the course of the next several months, we went through three stages of development. First, we familiarized ourselves with the tools we would be using throughout the project. Second, we created the building blocks of a convolutional neural network. Finally, we conducted experiments, fine-tuned parameters, and dove deep into the world of debugging.
Stage 1: Familiarizing Ourselves with the Tools
We spent the first few weeks of this project laying the groundwork and familiarizing ourselves with the tools required to build a convolutional neural network. Our initial tasks were foundational and allowed us to explore popular tools and recent advancements in the field.
Coding style conventions and best practices
With each pull request (code submission) on GitHub, Mars gave both content and style feedback. As he shared commenting conventions and explained conda environments, we quickly learned how to make our code clean, readable, and reproducible. Here is a link to our GitHub repository for reference: https://github.com/marshuang80/pneumothorax-detection
Collaboration
From the beginning, it was necessary to collaborate in all stages from ideation to coding to debugging. Normally, we (Kathryn and Rebekah) would divide up the tasks to create an initial implementation of each feature. When we encountered difficult bugs, we would work together on a shared screen to resolve them. We also shared our codebase and kept track of revisions through GitHub.
Weekly book club
Every week, Mars and Anuj provided us with notable research papers to read and present at the group meeting. This was something we looked forward to, and it enabled us to learn about various technical aspects of the field. At the start, we read foundational articles about different model architectures. Eventually, we covered topics like the use of contrastive learning in the training process and ethical issues related to the use of machine learning in healthcare. Listed below are some of the papers we particularly enjoyed:
- Gradient-Based Learning Applied to Document Recognition - LeCun et al.
  An overview of LeNet-5, one of the earliest convolutional neural networks, with a simple architecture.
- Learning Deep Features for Discriminative Localization - Zhou et al.
  Builds on previous research on global average pooling and introduces Class Activation Mapping for weakly supervised networks.
- Implementing Machine Learning in Health Care - Char et al.
  Considers the ethical challenges of introducing machine learning into healthcare, such as discrimination and racial biases arising from training data.
- Supervised Contrastive Learning - Khosla et al.
  Details the performance gains achieved when contrastive learning is applied to a fully supervised network.
Tools
We also invested time learning basic tools and frameworks, such as PyTorch, tensors, and multi-GPU training with DataParallel. To log our metrics, we used TensorBoard at first, but found that we preferred the Weights & Biases API, especially because it was easier to add custom charts and graphs. Later, we used PyTorch Lightning to organize our code.
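As a minimal illustration of two of these tools (not our actual training script), the sketch below wraps a torchvision model in DataParallel and logs a placeholder loss to Weights & Biases; the project name and the dummy loop are illustrative.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import wandb

# Build a backbone and spread each batch across all visible GPUs when more than one is available.
model = models.densenet121(pretrained=True)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Log metrics to Weights & Biases; the project name here is just a placeholder.
wandb.init(project="pneumothorax-detection")
for epoch in range(3):
    dummy_loss = 1.0 / (epoch + 1)  # stand-in for the loss computed by a real training loop
    wandb.log({"epoch": epoch, "train/loss": dummy_loss})
wandb.finish()
```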
Stage 2: Creating Our Building Blocks
In the second phase, we began setting up our model. There are many building blocks in a convolutional neural network, but we specifically customized our network architecture, loss function, image augmentations, data loader, and metric tracking. We added complexity to each of these building blocks as we experimented and found ways to improve our model’s performance. Here is a brief overview of what we did with each of these building blocks:
Network architecture
We built our own model by applying transfer learning to several commonly used architectures, replacing each network’s final classification layer with our own.
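Here is a minimal sketch of what that looks like in PyTorch, assuming a torchvision DenseNet backbone; the number of output classes is just a placeholder.

```python
import torch.nn as nn
import torchvision.models as models

num_classes = 3  # placeholder, e.g. healthy / pneumothorax / support device in our later setup

# Load a backbone pretrained on ImageNet and swap out its classification head.
model = models.densenet201(pretrained=True)
in_features = model.classifier.in_features               # input size of the existing head
model.classifier = nn.Linear(in_features, num_classes)   # our own classification layer
```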
Customized loss function
We created a weighted loss function that penalized misclassified pneumothorax cases more heavily than other conditions. This pushed the model to increase accuracy on pneumothorax cases in particular, instead of simply catching the majority of unhealthy cases while missing a key subgroup of one condition.
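The sketch below shows one way to express this idea, assuming a multi-hot label layout with a dedicated pneumothorax column; the weight values are illustrative, not the ones we settled on.

```python
import torch
import torch.nn.functional as F

# Illustrative column order: [healthy, pneumothorax, support device].
# The pneumothorax column is weighted more heavily than the others.
class_weights = torch.tensor([1.0, 5.0, 1.0])

def weighted_bce_loss(logits, targets):
    # Element-wise BCE-with-logits, then scale each column by its class weight.
    per_element = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return (per_element * class_weights).mean()

# Toy batch of 4 images with 3 labels each.
logits = torch.randn(4, 3)
targets = torch.randint(0, 2, (4, 3)).float()
loss = weighted_bce_loss(logits, targets)
```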
Custom image augmentations
We defined image augmentations such as rotation, resizing, cropping, Gaussian blur, and Gaussian noise. Many of these transformations are available through the Torchvision package, but we customized some of them for our specific purposes. Most notably, we combined resizing and cropping into one transformation so that we could random-crop training batches and center-crop validation and testing batches.
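A minimal sketch of that combined transform, assuming torchvision transforms; the sizes and the other augmentations shown are illustrative.

```python
from torchvision import transforms

class ResizeAndCrop:
    """Resize first, then random-crop for training or center-crop for validation/testing."""

    def __init__(self, resize_to=256, crop_to=224, train=True):
        self.resize = transforms.Resize(resize_to)
        self.crop = transforms.RandomCrop(crop_to) if train else transforms.CenterCrop(crop_to)

    def __call__(self, image):
        return self.crop(self.resize(image))

train_transform = transforms.Compose([
    ResizeAndCrop(train=True),
    transforms.RandomRotation(10),           # small random rotation
    transforms.GaussianBlur(kernel_size=3),  # Gaussian blur (later dropped; see Stage 3)
    transforms.ToTensor(),
])
val_transform = transforms.Compose([
    ResizeAndCrop(train=False),
    transforms.ToTensor(),
])
```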
Custom data loader
We made our data loader compatible with our custom dataset class. Defining the data loader in our own class also gave us complete control over when to shuffle, which samplers to use, and which transforms to apply to each batch. In particular, we defined our own sampler so that we could experiment with different sampling weights for the categories of healthy, unhealthy but not pneumothorax, and pneumothorax. This enabled us to address the overfitting that occurred because some categories contained only a limited number of cases.
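Here is a hedged sketch of the weighted sampling idea using PyTorch’s WeightedRandomSampler; the category weights and toy labels are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy categories: 0 = healthy, 1 = unhealthy but not pneumothorax, 2 = pneumothorax.
categories = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 0, 1])
images = torch.randn(len(categories), 3, 224, 224)   # stand-in image tensors

# Give rarer categories a higher chance of being drawn (illustrative weights).
category_weights = {0: 1.0, 1: 2.0, 2: 4.0}
sample_weights = torch.tensor([category_weights[int(c)] for c in categories])

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
loader = DataLoader(TensorDataset(images, categories), batch_size=4, sampler=sampler)
```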
Metric calculation and logging
We worked with both TensorBoard and Weights & Biases. In addition to the commonly tracked metrics (accuracy, ROC curve, F1 score, precision, sensitivity, etc.), we tracked AUROC and AUPRC for several subgroups within the dataset. Each image was labeled for 13 different chest conditions by an NLP model run on the radiologists’ descriptions of the image. For each of these 13 conditions, or subgroups, we checked model performance. This was important because we found that the overall model could achieve a high degree of accuracy while still classifying most pneumothorax cases incorrectly, likely because pneumothorax cases made up such a small percentage of the overall dataset. By tracking the subgroup metrics, we were able to notice inconsistencies and adjust our weighted loss function or sampler to compensate.
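As a small example of the subgroup tracking, the sketch below computes AUROC and AUPRC per condition with scikit-learn; the subgroup names are an illustrative subset and the labels and scores are random stand-ins.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

rng = np.random.default_rng(0)
subgroups = ["Pneumothorax", "Support Devices"]          # illustrative subset of the conditions
y_true = rng.integers(0, 2, size=(200, len(subgroups)))  # stand-in ground-truth labels
y_score = rng.random((200, len(subgroups)))              # stand-in predicted probabilities

for i, name in enumerate(subgroups):
    auroc = roc_auc_score(y_true[:, i], y_score[:, i])
    auprc = average_precision_score(y_true[:, i], y_score[:, i])
    print(f"{name}: AUROC={auroc:.3f}  AUPRC={auprc:.3f}")
```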
Stage 3: Experimentation and Debugging
With the basics in place, it was time for experimentation. Once we set up our training process, we began optimizing our model architecture, hyperparameters, and the types of images included in the dataset. As the need arose, we implemented custom features and experimented with those as well. Here is an overview of our process:
Initial experimentation
We started by running binary classification to identify each X-ray as healthy or unhealthy. We tried out various combinations of
- architectures (AlexNet, VGG, DenseNets, ResNets, Inception)
- loss functions (BCE, BCEWithLogits, Categorical Cross Entropy, Focal Loss) and
- optimizers (Adam, AdamW, SGD, RMSprop).
We used Bayesian search methods to find the highest-performing combination: DenseNet201 with BCEWithLogitsLoss and an AdamW optimizer. From there, we fine-tuned various hyperparameters to improve performance.
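A Weights & Biases sweep is one way to run such a Bayesian search; the sketch below is a hedged outline with an illustrative parameter grid and a placeholder train() body, not our actual sweep configuration.

```python
import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "val/auroc", "goal": "maximize"},
    "parameters": {
        "architecture": {"values": ["densenet201", "resnet50", "vgg16"]},
        "optimizer": {"values": ["adam", "adamw", "sgd"]},
        "learning_rate": {"distribution": "uniform", "min": 1e-5, "max": 1e-3},
    },
}

def train():
    run = wandb.init()
    cfg = wandb.config
    # Build the model/optimizer from cfg.architecture, cfg.optimizer, cfg.learning_rate, then train.
    wandb.log({"val/auroc": 0.5})  # placeholder for the real validation metric
    run.finish()

sweep_id = wandb.sweep(sweep_config, project="pneumothorax-detection")  # placeholder project name
wandb.agent(sweep_id, function=train, count=20)
```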
General experimentation strategies
Once this initial experimentation was complete, we faced the more difficult task of coming up with creative changes to improve performance. One way to identify potentially impactful changes is to look at the predicted results in as many different ways as you can. We looked at prediction histograms, ROC curves, falsely classified images, raw labels/predictions, and more for each of the chest conditions. Tracking each of these views unlocked insights into what we could adjust. Some examples of these insights are discussed below.
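For example, a histogram of predicted probabilities split by the true label makes it easy to see how well the two groups are separated; the sketch below uses random stand-in data.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)                              # stand-in true labels
y_score = np.clip(rng.normal(0.3 + 0.4 * y_true, 0.15), 0, 1)      # stand-in predicted probabilities

plt.hist(y_score[y_true == 0], bins=20, alpha=0.5, label="no pneumothorax")
plt.hist(y_score[y_true == 1], bins=20, alpha=0.5, label="pneumothorax")
plt.xlabel("predicted probability")
plt.ylabel("count")
plt.legend()
plt.savefig("prediction_histogram.png")
```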
Addressing overfitting
When comparing the training and validation loss, it became clear that the model was overfitting. In other words, the model was memorizing the training data instead of learning features it could then recognize in the validation data. This can be seen when the validation loss decreases and then plateaus while the training loss continues to decrease. To address this, we made three adjustments. First, we implemented a learning rate scheduler that automatically decreased the learning rate as training progressed, to improve our accuracy and avoid loss plateaus. Second, we implemented a weighted sampler so that we could upsample certain conditions; this was helpful for conditions with so few images that the model quickly memorized those exact images and, during validation, looked for them rather than for similar features. Third, we began validating in the middle of each training epoch, which allowed us to save a model checkpoint from the middle of the epoch, before the model began overfitting later in the epoch.
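For the scheduler, ReduceLROnPlateau is one way to lower the learning rate when the validation loss stops improving; the sketch below uses a stand-in model and illustrative factor/patience values.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                     # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

for epoch in range(10):
    val_loss = 1.0 / (epoch + 1)   # stand-in for a real validation loss
    scheduler.step(val_loss)       # lower the learning rate if val_loss has plateaued
```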
Addressing low scores in the pneumothorax condition
When comparing the metrics across various conditions, we noticed that the model was performing significantly worse on pneumothorax cases than on any other condition. We hypothesized that our model was detecting breathing tubes (or support devices) rather than features of pneumothorax itself. Radiologists such as Luke Oakden-Rayner have also identified this issue with AI models; you can read his article here.
To address this, we transitioned from binary classification to multilabel classification with three classes: healthy, pneumothorax identified, and support device identified. Since the model then provided predictions for each of these classes, we also implemented a weighted loss function that assigned a higher penalty to misclassifying the pneumothorax cases.
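A minimal sketch of that setup: a multi-hot target over [healthy, pneumothorax, support device], independent sigmoid predictions per label, and a BCE-with-logits loss whose pos_weight puts extra emphasis on the pneumothorax label. The weight values are illustrative.

```python
import torch
import torch.nn as nn

# Multi-hot target for [healthy, pneumothorax, support device]:
# this image shows a pneumothorax and a support device, and is not healthy.
target = torch.tensor([[0.0, 1.0, 1.0]])

logits = torch.randn(1, 3)          # stand-in for the model's three outputs
probs = torch.sigmoid(logits)       # independent probability per label
preds = (probs > 0.5).float()       # threshold each label separately

# Positive pneumothorax examples count five times as much (illustrative weight).
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0, 5.0, 1.0]))
loss = criterion(logits, target)
```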
Checking with physicians
Because our project has direct clinical ramifications, we reviewed our results with a physician. We showed them several of the misclassified images and asked for their perspective on what might be confusing the model. We learned that pneumothorax is extremely hard to identify, even for trained professionals, and that some of the image augmentations typically used in machine learning had made a subset of our images unidentifiable. In response, we increased the image resolution and dropped the Gaussian blur from the images. It was also important to drop the lateral images and focus solely on frontal images, because pneumothoraces look different from different viewpoints.
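A hedged sketch of those data changes, assuming the image metadata lives in a CSV with a "Frontal/Lateral" column as in the public CheXpert release; the file path, resolution, and transform are illustrative.

```python
import pandas as pd
from torchvision import transforms

# Keep only frontal views; lateral views are dropped entirely.
df = pd.read_csv("train.csv")                        # placeholder path to the metadata CSV
frontal_df = df[df["Frontal/Lateral"] == "Frontal"]

# Higher-resolution transform with Gaussian blur removed.
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
])
```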
After several rounds of experimentation, we found a model that performed well across all conditions and detected pneumothorax in particular with a high degree of accuracy. This model used multilabel classification, our own specified sampling weights, and an unweighted loss function.
Conclusion
This project provided an introduction to only some of the techniques available for machine learning. Going forward, we would like to develop an ensemble model by training several models and aggregating their results.
We’d like to thank Mars Huang, Anuj Pareek, and Matt Lungren for their investment of time and support. Together, they created a very effective mentoring framework for introducing undergraduate students to machine learning and enabling us to contribute meaningfully to a technical project. We’d like to thank the graduate students for bringing the team together in friendship. Mars, thank you for guiding us through the technical side of this project, being patient as we improved our coding fluency and turned your tips and tricks into habits, and for always being willing to hop on a call when we encountered particularly difficult bugs. Anuj, thank you for guiding us through the clinical and administrative portions of this project, selecting quality research papers for discussion, and providing a real-world clinical perspective on our results, all while dealing with a 9-hour time difference. Matt, thank you for believing in us and giving us this chance to learn as members of your lab. We couldn’t have asked for better mentors!
Feel free to contact Rebekah Westerlind (rswesterlind@gmail.com) or Kathryn Garcia (kathrynvgarcia@gmail.com) with any questions or comments.
Please contact aimicenter@stanford.edu for questions regarding the Stanford Center for Artificial Intelligence in Medicine and Imaging. Students interested in Stanford AIMI internships can find application information here.