---
title: Eye Disease Detection Models
emoji: 🛕🛕
colorFrom: green
colorTo: blue
sdk: gradio
sdk_version: 5.29.0
app_file: gradio_inference.py
pinned: true
short_description: Eye disease detection using deep learning models
license: apache-2.0
---
This repository contains a Gradio web application for eye disease detection using deep learning models. The application lets users upload fundus photographs and receive predictions for common eye conditions.

Key features:
- Easy-to-use web interface for eye disease detection
- Support for multiple model architectures (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
- Custom model loading from saved model checkpoints
- Visualization of prediction probabilities
- Attention heatmap visualization using GradCAM to show which regions the model focuses on
- Dockerized deployment option
The system can detect the following eye conditions:
- Central Serous Chorioretinopathy
- Diabetic Retinopathy
- Disc Edema
- Glaucoma
- Healthy (normal eye)
- Macular Scar
- Myopia
- Retinal Detachment
- Retinitis Pigmentosa
Requirements:

- Python 3.12+
- PyTorch 2.7.0+
- CUDA-compatible GPU (optional, but recommended for faster inference)
To install and run the application locally:

1. Clone this repository:

   ```bash
   git clone https://github.com/GilbertKrantz/eye-disease-detection.git
   cd eye-disease-detection
   ```

2. Install the required packages:

   ```bash
   pip install -r requirements.txt
   ```

3. Run the application (a sketch of its Gradio wiring follows these steps):

   ```bash
   python gradio_inference.py
   ```

4. Open your browser and go to http://localhost:7860
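For reference, gradio_inference.py presumably wires a prediction function into a Gradio interface along the lines of the sketch below. This is a simplified assumption: the component labels and the placeholder analyze function are illustrative, not the actual source.

```python
# Simplified, assumed sketch of the app's Gradio wiring (not the actual source).
import gradio as gr

ARCHITECTURES = ["MobileNetV4", "LeViT", "EfficientViT", "GENet", "RegNetX"]

def analyze(image, model_path, architecture):
    # Placeholder: the real app loads the selected model, runs inference,
    # and returns class probabilities plus a GradCAM overlay.
    return {"Healthy": 1.0}

demo = gr.Interface(
    fn=analyze,
    inputs=[
        gr.Image(type="pil", label="Fundus image"),
        gr.Textbox(label="Path to trained model (.pth), optional"),
        gr.Dropdown(ARCHITECTURES, label="Model architecture"),
    ],
    outputs=gr.Label(num_top_classes=9, label="Prediction"),
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)  # served at http://localhost:7860
```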
Alternatively, you can run the application with Docker:

1. Build the Docker image:

   ```bash
   docker build -t eye-disease-detection .
   ```

2. Run the container:

   ```bash
   docker run -p 7860:7860 eye-disease-detection
   ```

3. Open your browser and go to http://localhost:7860
To use the application:

1. Upload a fundus image of the eye
2. (Optional) Specify the path to your trained model file (.pth)
3. Select the model architecture (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX); model creation and checkpoint loading are sketched after these steps
4. Click "Analyze Image" to get the prediction
5. View the results, including:
   - The probability distribution across all disease classes
   - An attention heatmap showing which regions the model focused on for its prediction
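Under the hood, the app most likely builds the selected backbone with timm and loads the optional .pth checkpoint into it. The following is a minimal sketch of that flow; the timm model name, checkpoint layout, and helper names are assumptions, and the actual code in gradio_inference.py and ModelCreator.py may differ.

```python
# Minimal sketch of model creation, checkpoint loading, and prediction
# (assumed flow; the timm model name and checkpoint layout are guesses).
import timm
import torch
from PIL import Image
from timm.data import create_transform, resolve_data_config

NUM_CLASSES = 9  # the nine conditions listed above

def load_model(arch: str = "mobilenetv4_conv_small", checkpoint_path: str | None = None):
    model = timm.create_model(arch, pretrained=checkpoint_path is None, num_classes=NUM_CLASSES)
    if checkpoint_path:
        state = torch.load(checkpoint_path, map_location="cpu")
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]  # checkpoint saved as a training dict
        model.load_state_dict(state)
    return model.eval()

def predict(model, image_path: str) -> torch.Tensor:
    # Preprocess with the resize/normalization the backbone expects.
    transform = create_transform(**resolve_data_config({}, model=model))
    x = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        return torch.softmax(model(x), dim=1)[0]  # per-class probabilities
```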
The attention heatmap is generated using GradCAM (Gradient-weighted Class Activation Mapping), which visualizes the regions of the fundus image that the model considers most important for making its prediction:
- Red/Yellow areas: Regions the model focuses on most strongly
- Blue/Green areas: Regions with less influence on the prediction
This visualization helps in:
- Understanding the model's decision-making process
- Validating that the model is looking at clinically relevant features
- Building trust in the AI's predictions by making them interpretable
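One common way to produce such an overlay is the pytorch-grad-cam package. The sketch below assumes that package and leaves the choice of target layer to the caller; gradio_inference.py may implement its heatmap differently.

```python
# Sketch of a GradCAM overlay using the pytorch-grad-cam package
# (an assumption; the app may generate its heatmap differently).
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def attention_heatmap(model, target_layer, input_tensor, rgb_image: np.ndarray, class_idx: int):
    """Overlay a GradCAM heatmap for `class_idx` on `rgb_image` (float32 array in [0, 1]).

    `target_layer` must produce spatial activations, e.g. the last convolutional
    stage of a CNN backbone; transformer-style backbones (LeViT, EfficientViT)
    need a model-specific choice and reshape transform.
    """
    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale = cam(input_tensor=input_tensor,
                    targets=[ClassifierOutputTarget(class_idx)])[0]  # HxW map in [0, 1]
    return show_cam_on_image(rgb_image, grayscale, use_rgb=True)  # red/yellow = strongest focus
```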
This repository focuses on inference. To train your own models, use the main training script and follow these steps:

1. Prepare your dataset in the required directory structure (see the layout sketch after these steps)
2. Train a model with the main.py script:

   ```bash
   python main.py --train-dir "/path/to/training/data" --eval-dir "/path/to/eval/data" --model mobilenetv4 --epochs 20 --save-model "my_model.pth"
   ```

3. Use the saved model with the inference application
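The required directory structure is most likely an ImageFolder-style layout with one sub-directory per class; this is an assumption, so check DatasetHandler.py for the authoritative format.

```python
# Assumed ImageFolder-style layout for --train-dir / --eval-dir
# (an assumption; check DatasetHandler.py for the authoritative format):
#
#   /path/to/training/data/
#       Central Serous Chorioretinopathy/  img_001.jpg ...
#       Diabetic Retinopathy/              img_002.jpg ...
#       ...
#       Retinitis Pigmentosa/              img_xyz.jpg
from torchvision import datasets, transforms

train_ds = datasets.ImageFolder(
    "/path/to/training/data",
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)
print(train_ds.classes)  # class names are inferred from the sub-directory names
```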
The repository is organized as follows:

```
.
├── gradio_inference.py    # Main Gradio application
├── requirements.txt       # Python dependencies
├── Dockerfile             # Docker configuration
├── README.md              # This documentation
├── utils/                 # Utility modules
│   ├── ModelCreator.py    # Model architecture definitions
│   ├── Evaluator.py       # Model evaluation utilities
│   ├── DatasetHandler.py  # Dataset handling utilities
│   ├── Trainer.py         # Model training utilities
│   └── Callback.py        # Training callbacks
└── main.py                # Main training script
```
The performance of the models depends on the quality of training data and the specific architecture used. In general, these models can achieve accuracy rates of 85-95% on standard eye disease datasets.
You can customize the application in several ways:
- Add example images in the Gradio interface
- Extend the list of supported classes by modifying the CLASSES variable in gradio_inference.py
- Add support for additional model architectures in ModelCreator.py (both customizations are sketched below)
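For illustration, the last two bullets might correspond to code along the following lines; the variable contents and the friendly-name-to-timm mapping are assumptions based on this README, not the actual source.

```python
# Assumed shape of the customization points mentioned above
# (names and mappings are illustrative, not taken from the source).
import timm

CLASSES = [
    "Central Serous Chorioretinopathy", "Diabetic Retinopathy", "Disc Edema",
    "Glaucoma", "Healthy", "Macular Scar", "Myopia",
    "Retinal Detachment", "Retinitis Pigmentosa",
]

# Adding a new architecture can be as simple as mapping a friendly name
# to a timm model identifier in ModelCreator.py.
TIMM_NAMES = {
    "mobilenetv4": "mobilenetv4_conv_small",
    "levit": "levit_128",
    "efficientvit": "efficientvit_b0",
    "genet": "gernet_s",
    "regnetx": "regnetx_016",
    "convnext": "convnext_tiny",  # example of an added architecture
}

def create_model(name: str, num_classes: int = len(CLASSES)):
    return timm.create_model(TIMM_NAMES[name], pretrained=True, num_classes=num_classes)
```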
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Acknowledgments:

- The models are built using PyTorch and the TIMM library
- The web interface is built using Gradio
- Special thanks to the open-source community for making this project possible