Are you interested in building a web application that can diagnose plant diseases using AI? Imagine uploading a photo of a plant leaf and receiving an instant diagnosis that tells you whether the plant is healthy or which disease it has. In this tutorial, we deploy a pre-trained plant disease detection model with TensorFlow and Flask. We walk through the code step by step and explain each part, so you can build your own AI-powered web application.
Introduction
With advancements in AI, we can now deploy sophisticated machine learning models to the web, making powerful tools accessible to anyone with an internet connection. In this tutorial, we’ll deploy a plant disease detection model trained on leaf images of pepper, potato, and tomato plants, covering both healthy leaves and common diseases. The model is served using Flask, a lightweight Python web framework that lets us build a web application users can interact with.
By the end of this guide, you'll have a fully functional web app that can predict plant diseases based on uploaded images. This application can be extremely valuable for farmers, gardeners, or anyone interested in plant health.
What You Will Learn
In this tutorial, you will learn:
How to set up a Flask web application.
How to load a pre-trained TensorFlow model.
How to process and predict plant diseases from uploaded images.
How to display the prediction results in a user-friendly format.
Prerequisites
Before we begin, make sure you have the following:
Basic understanding of Python and TensorFlow.
Flask installed in your Python environment (pip install flask).
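TensorFlow installed (pip install tensorflow).
gevent and Pillow installed (pip install gevent pillow). gevent provides the WSGI server used to run the app, and Pillow is the library Keras uses to read image files.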
A pre-trained model (PlantDNet.h5).
Setting Up the Environment
First, make sure Flask, TensorFlow, gevent, and Pillow are installed in your environment. You can install all of them with pip:
pip install flask tensorflow gevent pillow
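Before running the app, you can confirm that every module the script imports is available. The sketch below uses only the standard library's importlib.util.find_spec. REQUIRED_MODULES is our own assumption of what the script needs, and note that Pillow is imported under the name PIL.

```python
import importlib.util

# Import names (not pip package names) that the tutorial's script relies on.
# Assumption: this mirrors the script's imports; Pillow is imported as "PIL".
REQUIRED_MODULES = ["flask", "tensorflow", "numpy", "werkzeug", "gevent", "PIL"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

find_spec only locates a module without importing it, so the check stays fast even for TensorFlow.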
Once these dependencies are installed, you’re ready to dive into the code.
Understanding the Code
Let's break down the code into its core components to understand how the application works.
1. Importing Required Libraries
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing import image
from flask import Flask, redirect, url_for, request, render_template
from werkzeug.utils import secure_filename
from gevent.pywsgi import WSGIServer
Here, we import the necessary libraries:
os: For interacting with the operating system, such as handling file paths.
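tensorflow and keras: For loading the pre-trained model and preprocessing images (tensorflow.keras.preprocessing.image).
numpy: For array operations on the image data and the prediction scores.
flask: For setting up the web application.
werkzeug.utils: For securely handling file uploads (secure_filename).
gevent.pywsgi: For serving the Flask application with gevent's WSGIServer instead of Flask's development server.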
2. Initializing the Flask Application
app = Flask(__name__)
This line initializes the Flask application, creating an instance of the Flask class.
3. Loading the Pre-Trained Model
model = tf.keras.models.load_model('PlantDNet.h5', compile=False)
print('Model loaded. Check http://127.0.0.1:5000/')
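We load the pre-trained model using load_model from TensorFlow. The relative path 'PlantDNet.h5' is resolved against the current working directory, not the script's folder. Either start the script from the folder that contains the model file, or pass an absolute path. Passing compile=False skips restoring the training configuration (optimizer, loss, and metrics). That configuration is not needed for inference with model.predict, and skipping it avoids errors when the saved model used a custom loss or metric.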
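A relative path like 'PlantDNet.h5' depends on where the script is launched from. One option is to build the path from the script's own location, as the upload route does for the uploads folder with os.path.dirname(__file__). Here is a minimal sketch with a hypothetical resolve_model_path helper:

```python
from pathlib import Path


def resolve_model_path(filename="PlantDNet.h5", base_dir=None):
    """Return an absolute-style path to the model file inside base_dir.

    Hypothetical helper: when base_dir is None it uses the folder containing
    this script, so the result no longer depends on the working directory.
    """
    base = Path(base_dir) if base_dir is not None else Path(__file__).resolve().parent
    return base / filename
```

In the app, the load call would then become model = tf.keras.models.load_model(str(resolve_model_path()), compile=False).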
4. Defining the Prediction Function
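def model_predict(img_path, model):
    # Load the image as RGB and resize it to the 64x64 input size the model uses
    img = image.load_img(img_path, color_mode='rgb', target_size=(64, 64))
    x = image.img_to_array(img)      # float32 array of shape (64, 64, 3)
    x = np.expand_dims(x, axis=0)    # add a batch dimension -> (1, 64, 64, 3)
    x = np.array(x, 'float32')
    x /= 255                         # scale pixel values to the 0-1 range
    preds = model.predict(x)
    return preds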
The model_predict function processes the uploaded image and makes a prediction:
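image.load_img: Loads the image from the specified path as an RGB image and resizes it to 64x64 pixels, the input size the model expects.
image.img_to_array: Converts the image to a float32 NumPy array of shape (64, 64, 3).
np.expand_dims: Adds a batch dimension, giving shape (1, 64, 64, 3), because Keras models predict on batches of inputs.
x /= 255: Scales pixel values from the 0-255 range to 0-1. This presumably matches the preprocessing used when the model was trained.
model.predict: Feeds the processed image to the model and returns an array with one row of class scores for the single image in the batch.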
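The preprocessing steps above can be reproduced with NumPy alone. This makes it easy to check shapes and value ranges without loading a model. The sketch assumes the image is already 64x64 RGB, since resizing is left to load_img, and the all-white test image is synthetic.

```python
import numpy as np


def preprocess_array(rgb_image):
    """Turn an (H, W, 3) RGB image into a (1, H, W, 3) float32 batch scaled to [0, 1].

    Mirrors the steps in model_predict after load_img; resizing is assumed done.
    """
    x = np.array(rgb_image, dtype="float32")  # copy, same dtype img_to_array returns
    x = np.expand_dims(x, axis=0)             # add the batch dimension
    x /= 255.0                                # scale 0-255 pixel values to 0.0-1.0
    return x


# Synthetic 64x64 all-white image, used only to show the shapes and range.
demo = np.full((64, 64, 3), 255, dtype=np.uint8)
batch = preprocess_array(demo)
print(batch.shape, batch.dtype, batch.max())  # (1, 64, 64, 3) float32 1.0
```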
5. Setting Up the Main Route
@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')
The index function handles GET requests to the root URL. It renders an HTML template (index.html), which serves as the main page of the application where users can upload images.
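The app only works when its supporting files are where the code expects them, so a quick check can save some debugging. Below is a standard-library sketch; missing_files and REQUIRED_FILES are hypothetical names, and the layout is our reading of the code. Flask's render_template looks for templates/index.html by default, and load_model reads PlantDNet.h5 relative to the working directory. The sketch lists which of these files are absent under a project folder.

```python
import os
import tempfile

# Assumed layout, read from the tutorial's code: render_template('index.html') looks in
# a "templates" folder, and load_model('PlantDNet.h5') expects the model file here too.
REQUIRED_FILES = ["templates/index.html", "PlantDNet.h5"]


def missing_files(root, required=REQUIRED_FILES):
    """Return the required files ('/'-separated, relative to root) that do not exist."""
    return [rel for rel in required
            if not os.path.isfile(os.path.join(root, *rel.split("/")))]


# Demo on a throwaway folder that has the template but no model file.
with tempfile.TemporaryDirectory() as demo_root:
    os.makedirs(os.path.join(demo_root, "templates"))
    with open(os.path.join(demo_root, "templates", "index.html"), "w") as fh:
        fh.write("<h1>placeholder</h1>")
    demo_missing = missing_files(demo_root)

print(demo_missing)  # ['PlantDNet.h5']
```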
6. Handling Image Uploads and Predictions
@app.route('/predict', methods=['GET', 'POST'])
def upload():
    if request.method == 'POST':
        # Get the file from post request
        f = request.files['file']
        if f.filename == '':
            return 'No file selected', 400
        # Save the file to ./uploads, creating the folder if it does not exist
        basepath = os.path.dirname(__file__)
        os.makedirs(os.path.join(basepath, 'uploads'), exist_ok=True)
        file_path = os.path.join(basepath, 'uploads', secure_filename(f.filename))
        f.save(file_path)
        # Make prediction
        preds = model_predict(file_path, model)
        print(preds[0])
        disease_class = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight',
                         'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight',
                         'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
                         'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
                         'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
        a = preds[0]
        ind = np.argmax(a)
        print('Prediction:', disease_class[ind])
        result = disease_class[ind]
        return result
    # A Flask view must return a response, so send GET requests back to the upload page
    return redirect(url_for('index'))
The upload function handles POST requests to the /predict route. When a user uploads an image:
The image file is saved securely to the uploads directory.
The image is passed to the model_predict function, which returns prediction probabilities for each class.
The class with the highest probability is selected as the predicted disease, and the result is returned to the user.
The disease_class list contains the names of all the classes the model can predict. The np.argmax(a) function finds the index of the highest probability in the prediction array, which corresponds to the predicted class.
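One easy mistake is a mismatch between the label list and the model's output layer, for example after retraining with a different number of classes. np.argmax still returns an index, but that index may be wrong or out of range. Below is a sketch with a hypothetical label_for helper that checks the lengths before mapping the index to a name. It assumes the list order matches the class indices used during training, and the scores in the demo are made up.

```python
import numpy as np

# Copied from the tutorial's disease_class list; assumed to be in the same order
# as the class indices the model was trained with.
DISEASE_CLASSES = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight',
                   'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight',
                   'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
                   'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
                   'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']


def label_for(scores, classes=DISEASE_CLASSES):
    """Map one image's score vector (shape (n,) or (1, n)) to its highest-scoring label."""
    scores = np.asarray(scores).ravel()
    if scores.shape[0] != len(classes):
        raise ValueError(
            f"model returned {scores.shape[0]} scores but {len(classes)} labels are defined")
    return classes[int(np.argmax(scores))]


# Made-up scores for illustration, not real model output: index 4 is the largest.
fake_scores = np.zeros((1, 15))
fake_scores[0, 4] = 0.9
print(label_for(fake_scores))  # Potato___healthy
```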
Let's Understand the Code Line by Line
Route Definition and Function Setup
@app.route('/predict', methods=['GET', 'POST'])
def upload():
@app.route('/predict', methods=['GET', 'POST']): This line is a route decorator in Flask that maps the URL path /predict to the upload function. The route supports both GET and POST HTTP methods, but the primary focus here is on handling POST requests since that's when file uploads and predictions occur.
def upload():: This line defines the upload function, which is executed when a request is made to the /predict route.
Handling POST Requests
if request.method == 'POST':
if request.method == 'POST':: This conditional checks if the request method is POST. If it is, the code inside the block will be executed. POST requests are typically used to submit data to the server, such as a file in this case.
File Upload and Saving
# Get the file from post request
f = request.files['file']
# Save the file to ./uploads
basepath = os.path.dirname(__file__)
file_path = os.path.join(basepath, 'uploads', secure_filename(f.filename))
f.save(file_path)
f = request.files['file']: This line retrieves the file from the POST request. The request.files object is a dictionary-like object where the key is the name of the file input field in the HTML form (e.g., <input type="file" name="file">). The file itself is stored in f.
basepath = os.path.dirname(__file__): This line determines the base directory path of the current script. The os.path.dirname(__file__) function returns the directory name of the file where this code is running.
file_path = os.path.join(basepath, 'uploads', secure_filename(f.filename)): Here, the code constructs the full path where the uploaded file will be saved. The os.path.join() function concatenates the base directory, the 'uploads' directory, and the filename. The secure_filename() function sanitizes the client-supplied name by stripping path separators and other unsafe characters (for example, '../../../etc/passwd' becomes 'etc_passwd'), so an upload cannot be written outside the uploads folder.
f.save(file_path): This line saves the uploaded file to the specified path on the server, inside the 'uploads' directory. f.save() does not create missing folders, so the uploads directory must already exist. Create it by hand or with os.makedirs(..., exist_ok=True).
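The saving step has two practical gaps: it accepts any file type, and two uploads with the same name overwrite each other. One possible fix, using only the standard library, checks the extension against a whitelist and stores each upload under a random name. ALLOWED_EXTENSIONS, allowed_file, and unique_upload_path are hypothetical names, not part of the tutorial.

```python
import os
import tempfile
import uuid

# Hypothetical whitelist; the tutorial itself accepts any uploaded file.
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def allowed_file(filename):
    """True if the filename ends in a whitelisted image extension (case-insensitive)."""
    return os.path.splitext(filename)[1].lower() in ALLOWED_EXTENSIONS


def unique_upload_path(upload_dir, safe_name):
    """Create upload_dir if needed and return a fresh path that keeps safe_name's extension."""
    os.makedirs(upload_dir, exist_ok=True)
    ext = os.path.splitext(safe_name)[1].lower()
    return os.path.join(upload_dir, uuid.uuid4().hex + ext)


# Demo in a throwaway folder: the same client filename yields two distinct paths.
with tempfile.TemporaryDirectory() as tmp:
    uploads = os.path.join(tmp, "uploads")
    first = unique_upload_path(uploads, "leaf.PNG")
    second = unique_upload_path(uploads, "leaf.PNG")
    uploads_created = os.path.isdir(uploads)
```

Inside upload(), you might return ('Unsupported file type', 400) when allowed_file(f.filename) is False. You would then save to unique_upload_path(os.path.join(basepath, 'uploads'), secure_filename(f.filename)).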
Making a Prediction
# Make prediction
preds = model_predict(file_path, model)
print(preds[0])
preds = model_predict(file_path, model): After saving the file, the code calls model_predict(), defined earlier, passing the file path and the pre-loaded model. The function preprocesses the image and returns the model's class scores for it.
print(preds[0]): This line prints the raw scores for the uploaded image to the server console. preds has one row per image in the batch, and here the batch holds a single image. This is useful for debugging, to see what the model outputs before it is translated into a label.
Interpreting and Returning the Prediction
disease_class = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight',
'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight',
'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
a = preds[0]
ind = np.argmax(a)
print('Prediction:', disease_class[ind])
result = disease_class[ind]
disease_class = [...]: This list contains the possible classes or labels that the model can predict. Each entry corresponds to a specific plant disease or a healthy state for different crops (e.g., pepper, potato, tomato). The position of each label in the list corresponds to the index in the prediction array returned by the model.
a = preds[0]: Because model_predict sends a batch containing a single image, preds has exactly one row, and this line extracts it. The array a holds one score per class in disease_class. These are probabilities if the model ends in a softmax layer.
ind = np.argmax(a): The np.argmax() function returns the index of the highest value in the array a. This index corresponds to the class label in disease_class that the model predicts as the most likely.
print('Prediction:', disease_class[ind]): This line prints the predicted disease class to the server console, allowing you to see which disease (or healthy state) the model has identified.
result = disease_class[ind]: The predicted disease class is stored in the variable result, which will be returned to the client.
Returning the Result
return result
return result: The predicted class name is returned as the response to the POST request. Flask sends a returned string as the response body with status 200, so the browser simply shows the label as text. It could also be rendered into an HTML template or sent as JSON, depending on how the rest of the application is set up.
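For a JSON response, it is often useful to return more than the single best label. Below is a sketch with a hypothetical top_k_payload helper that ranks the scores with np.argsort. The scores are only probabilities if the model's last layer is a softmax, which we assume but the tutorial does not state. The demo uses made-up scores and placeholder labels.

```python
import json

import numpy as np


def top_k_payload(scores, classes, k=3):
    """Build a JSON-serialisable dict with the best label and the k highest-scoring classes."""
    scores = np.asarray(scores, dtype="float64").ravel()
    order = np.argsort(scores)[::-1][:k]  # class indices, highest score first
    top = [{"label": classes[int(i)], "score": round(float(scores[i]), 4)} for i in order]
    return {"prediction": top[0]["label"], "top_k": top}


# Made-up scores for three placeholder labels, for illustration only.
payload = top_k_payload([0.1, 0.7, 0.2], ["a", "b", "c"], k=2)
print(json.dumps(payload))
# {"prediction": "b", "top_k": [{"label": "b", "score": 0.7}, {"label": "c", "score": 0.2}]}
```

In the Flask view you could then return jsonify(top_k_payload(preds[0], disease_class)) instead of the plain string.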
Handling Non-POST Requests
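return redirect(url_for('index'))
return redirect(url_for('index')): If the request method is not POST (for example, someone opens /predict directly in the browser), the function sends the visitor back to the upload page. A view must always return a response. Returning None here instead would make Flask raise a TypeError ("The view function did not return a valid response"), which the client sees as a 500 Internal Server Error.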
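To check the GET fallback without the model, here is a minimal stand-alone Flask sketch that mirrors only the routing logic of upload(). The route bodies are placeholders, not the tutorial's prediction code. Flask's built-in test client sends requests to the app without starting a server.

```python
from flask import Flask, redirect, request, url_for

app = Flask(__name__)


@app.route("/")
def index():
    return "upload page"  # placeholder for render_template('index.html')


@app.route("/predict", methods=["GET", "POST"])
def upload():
    if request.method == "POST":
        return "prediction goes here"  # placeholder for the real model call
    # A view must return a response; send GET visitors back to the upload form.
    return redirect(url_for("index"))


client = app.test_client()
resp = client.get("/predict")
print(resp.status_code)  # 302
```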
7. Running the Flask Application
if __name__ == '__main__':
    # Serve the app with gevent
    http_server = WSGIServer(('', 5000), app)
    http_server.serve_forever()
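Finally, the app is served with gevent's WSGIServer rather than Flask's built-in development server (app.run()), which is not intended for production use. WSGIServer handles concurrent requests with lightweight greenlets. Binding to '' makes the application listen on all network interfaces on port 5000.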
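If you deploy to a host that assigns the port at runtime, a hard-coded 5000 can get in the way. The small sketch below assumes hypothetical HOST and PORT environment variables; many hosting platforms do set PORT. It builds the (host, port) tuple that WSGIServer takes.

```python
import os


def server_address(env=None, default_port=5000):
    """Return the (host, port) tuple for WSGIServer from hypothetical HOST/PORT variables."""
    env = os.environ if env is None else env
    host = env.get("HOST", "")                 # '' listens on all interfaces, as in the tutorial
    port = int(env.get("PORT", default_port))  # environment values are strings
    return host, port
```

The server line would then read http_server = WSGIServer(server_address(), app).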
Building the Frontend
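To complete the application, you’ll need to create the index.html file that lets users upload images. Place it in a folder named templates next to your script, because that is where Flask's render_template looks by default. Below is a simple example: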
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Plant Disease Detection</title>
</head>
<body>
<h1>Upload a plant leaf image to detect disease</h1>
<form action="/predict" method="post" enctype="multipart/form-data">
<input type="file" name="file">
<input type="submit" value="Predict">
</form>
</body>
</html>
This HTML form allows users to select and upload an image file, which is then sent to the /predict route for processing.
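The form sends a multipart/form-data POST with the image in a field named file. To exercise the /predict endpoint from a script instead of a browser, you can build the same request with the standard library alone. build_multipart and make_predict_request are hypothetical helpers, and the example URL assumes the server runs locally on port 5000.

```python
import mimetypes
import os
import urllib.request
import uuid


def build_multipart(field_name, filename, data, content_type="application/octet-stream"):
    """Encode one file as a multipart/form-data body; returns (body_bytes, content_type_header)."""
    boundary = uuid.uuid4().hex  # random boundary, vanishingly unlikely to occur in the file bytes
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n"
        "\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def make_predict_request(url, image_path):
    """Prepare a POST that mimics the HTML form: the image goes in the field named 'file'."""
    with open(image_path, "rb") as fh:
        data = fh.read()
    mime = mimetypes.guess_type(image_path)[0] or "application/octet-stream"
    body, header = build_multipart("file", os.path.basename(image_path), data, mime)
    return urllib.request.Request(url, data=body, headers={"Content-Type": header}, method="POST")
```

With the server running, urllib.request.urlopen(make_predict_request("http://127.0.0.1:5000/predict", "leaf.jpg")).read().decode() should return the predicted label. Here leaf.jpg stands for any leaf photo on disk.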
Now we have successfully created a plant disease detection web application using TensorFlow and Flask. This application can be expanded and integrated into larger systems for agriculture, where real-time plant disease detection could have a significant impact on crop management and yield.
You can modify and improve this basic framework by adding more features such as displaying the image along with the prediction, handling multiple uploads, or even deploying it to a cloud service for broader access.
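Handling several uploads at once mostly comes down to batching. Keras models already accept an array of shape (N, height, width, channels), so images can be stacked and predicted in one call. The sketch assumes each image is already loaded and resized to 64x64 RGB. fake_predict is a dummy stand-in for model.predict, so the example runs without PlantDNet.h5.

```python
import numpy as np


def predict_labels(model_predict_fn, images, classes):
    """Stack (H, W, 3) images into one batch and return one label per image.

    model_predict_fn stands in for model.predict: it must map an (N, H, W, 3)
    array to an (N, num_classes) array of scores.
    """
    batch = np.stack([np.asarray(img, dtype="float32") / 255.0 for img in images])
    scores = model_predict_fn(batch)
    return [classes[int(i)] for i in np.argmax(scores, axis=1)]


# Dummy stand-in for model.predict so the example runs without the real model:
# it scores class 0 higher for dark images and class 1 higher for bright ones.
def fake_predict(batch):
    brightness = batch.mean(axis=(1, 2, 3))
    return np.stack([1.0 - brightness, brightness], axis=1)


dark = np.zeros((64, 64, 3), dtype=np.uint8)
bright = np.full((64, 64, 3), 255, dtype=np.uint8)
labels = predict_labels(fake_predict, [dark, bright], ["dark_class", "bright_class"])
print(labels)  # ['dark_class', 'bright_class']
```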
Complete Code
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing import image
# Flask utils
from flask import Flask, redirect, url_for, request, render_template
from werkzeug.utils import secure_filename
from gevent.pywsgi import WSGIServer
# Define a flask app
app = Flask(__name__)
# Model saved with Keras model.save()
# You can also use pretrained model from Keras
# Check https://keras.io/applications/
model = tf.keras.models.load_model('PlantDNet.h5', compile=False)
print('Model loaded. Check http://127.0.0.1:5000/')
def model_predict(img_path, model):
    # Load the image as RGB at the 64x64 input size used by the model
    img = image.load_img(img_path, color_mode='rgb', target_size=(64, 64))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)   # add a batch dimension
    x = np.array(x, 'float32')
    x /= 255                        # scale pixel values to the 0-1 range
    preds = model.predict(x)
    return preds
@app.route('/', methods=['GET'])
def index():
    # Main page
    return render_template('index.html')
@app.route('/predict', methods=['GET', 'POST'])
def upload():
    if request.method == 'POST':
        # Get the file from post request
        f = request.files['file']
        if f.filename == '':
            return 'No file selected', 400
        # Save the file to ./uploads, creating the folder if needed
        basepath = os.path.dirname(__file__)
        os.makedirs(os.path.join(basepath, 'uploads'), exist_ok=True)
        file_path = os.path.join(
            basepath, 'uploads', secure_filename(f.filename))
        f.save(file_path)
        # Make prediction
        preds = model_predict(file_path, model)
        print(preds[0])
        disease_class = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight',
                         'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight',
                         'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
                         'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
                         'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
        a = preds[0]
        ind = np.argmax(a)
        print('Prediction:', disease_class[ind])
        result = disease_class[ind]
        return result
    # Non-POST requests: send the visitor back to the upload page
    return redirect(url_for('index'))
if __name__ == '__main__':
    # app.run(port=5002, debug=True)
    # Serve the app with gevent (serve_forever blocks until the process stops)
    http_server = WSGIServer(('', 5000), app)
    http_server.serve_forever()
This tutorial demonstrates the power of combining machine learning with web technologies. With just a few lines of code, you can turn a pre-trained TensorFlow model into a fully functional web application. Whether for personal projects or real-world applications, the ability to deploy AI models to the web opens up endless possibilities.
Happy coding!
For the complete solution or any assistance with deploying a Plant Disease Detection Model with TensorFlow and Flask, feel free to contact us.