
Skin Disease Classification Using Deep Learning by Implementing MobileNet, a Pre-trained Model – Part 2


In the previous blog post, we discussed how to build a Convolutional Neural Network (CNN) from scratch for skin disease classification using the HAM10000 dataset. We covered setting up the dataset, creating training and validation sets, augmenting the data, and training a CNN model. The CNN architecture stacked several convolutional layers with max-pooling and dropout layers, which helped it learn the intricate features of skin disease images. The model was then evaluated with several metrics to assess its effectiveness.


In this second part of the blog series, we will delve into a more advanced approach by implementing a pre-trained model, MobileNet, for skin disease classification. MobileNet is a lightweight deep learning model that is highly efficient and well-suited for mobile and edge devices. We will explore how to modify MobileNet for our specific classification task, train it, and evaluate its performance.


Overview of MobileNet

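MobileNet is a deep learning model developed by Google for mobile and embedded vision applications. It is built on depthwise separable convolutions. Each of these splits a standard convolution into two steps: a depthwise convolution that applies one filter per input channel, followed by a 1×1 pointwise convolution that combines the channels. This greatly reduces the number of parameters and the amount of computation, making MobileNet faster and more efficient than traditional CNNs.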
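The following back-of-the-envelope sketch makes the savings concrete. It is not taken from the MobileNet code. It counts the weights of a standard 3×3 convolution and of its depthwise separable equivalent, ignoring biases. The layer size of 512 input and 512 output channels is an illustrative assumption.

```python
def standard_conv_params(k: int, c_in: int, c_out: int) -> int:
    # One k x k x c_in filter for each of the c_out output channels
    return k * k * c_in * c_out


def depthwise_separable_params(k: int, c_in: int, c_out: int) -> int:
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 convolution that mixes the channels
    return depthwise + pointwise


# Illustrative layer: 3x3 kernel, 512 input channels, 512 output channels
std = standard_conv_params(3, 512, 512)
sep = depthwise_separable_params(3, 512, 512)
print(std, sep, round(sep / std, 3))  # 2359296 266752 0.113
```

The ratio works out to 1/c_out + 1/k². For 3×3 kernels, a depthwise separable layer therefore needs roughly 8 to 9 times fewer weights, and multiply-adds, than a standard convolution.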


What Will Be Covered in This Blog


In this blog, we will cover the following steps:


  1. Loading and Preparing the Data: Setting up the dataset for training and validation using ImageDataGenerator.

  2. Modifying MobileNet: Customizing the MobileNet architecture to fit our skin disease classification task.

  3. Training the Model: Training the modified MobileNet with appropriate class weights and callbacks.

  4. Evaluating the Model: Assessing the model's performance on the validation set.

  5. Visualizing Results: Plotting training curves, confusion matrices, and generating classification reports.


Step-by-Step Implementation

1. Import Libraries

from numpy.random import seed
seed(101)
import tensorflow
# TF 2.x: tensorflow.random.set_seed replaces the removed tensorflow.set_random_seed
tensorflow.random.set_seed(101)

import os
import itertools
import shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# %matplotlib inline

from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

2. Loading and Preparing the Data

We start by setting up the data generators for training, validation, and testing. The ImageDataGenerator is used to apply preprocessing specific to MobileNet.


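from tensorflow.keras.preprocessing.image import ImageDataGenerator

# train_path, valid_path, train_batch_size and val_batch_size are defined in Part 1.
# MobileNet's ImageNet weights expect 224x224 RGB input.
image_size = 224

# mobilenet.preprocess_input scales pixel values from [0, 255] to [-1, 1]
datagen = ImageDataGenerator(
    preprocessing_function=tensorflow.keras.applications.mobilenet.preprocess_input)

train_batches = datagen.flow_from_directory(train_path,
                                            target_size=(image_size, image_size),
                                            batch_size=train_batch_size)

valid_batches = datagen.flow_from_directory(valid_path,
                                            target_size=(image_size, image_size),
                                            batch_size=val_batch_size)

# The validation folder doubles as the test set. shuffle=False keeps the order of
# the predictions aligned with test_batches.classes for the confusion matrix.
test_batches = datagen.flow_from_directory(valid_path,
                                           target_size=(image_size, image_size),
                                           batch_size=1,
                                           shuffle=False)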

Here, the preprocessing_function parameter applies MobileNet's own preprocessing to every image. This scales pixel values from [0, 255] to [-1, 1], the input range the pre-trained weights expect.
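For intuition, the scaling that MobileNet's preprocess_input performs can be reproduced with NumPy. This is a minimal sketch assuming the 'tf'-style scaling Keras uses for MobileNet (x / 127.5 - 1). The helper name mobilenet_style_scale is hypothetical. In the actual pipeline, keep using the library function.

```python
import numpy as np


def mobilenet_style_scale(images: np.ndarray) -> np.ndarray:
    """Map pixel values from [0, 255] to [-1, 1] (hypothetical helper for illustration)."""
    return images.astype(np.float32) / 127.5 - 1.0


pixels = np.array([0, 127.5, 255])
print(mobilenet_style_scale(pixels))  # [-1.  0.  1.]
```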




3. Modifying MobileNet

Next, we load the pre-trained MobileNet model and modify it for our classification task. The original MobileNet model was trained on ImageNet for general image classification with 1000 classes. We adapt it to our 7-class skin disease classification by replacing its 1000-class output layers with custom layers.

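# Load MobileNet with ImageNet weights but without its 1000-class head.
# pooling='avg' ends the network with global average pooling (a 1024-d vector),
# the layer that mobile.layers[-6] pointed to in older Keras versions. Newer
# versions changed the head's layer order, so layers[-6] is no longer reliable.
mobile = tensorflow.keras.applications.mobilenet.MobileNet(
    weights='imagenet', include_top=False, pooling='avg',
    input_shape=(image_size, image_size, 3))

x = mobile.output
x = Dropout(0.25)(x)                              # regularize the new head
predictions = Dense(7, activation='softmax')(x)   # one output per skin lesion class

model = Model(inputs=mobile.input, outputs=predictions)

# Freeze every layer except the last 23; only those are updated during training
for layer in model.layers[:-23]:
    layer.trainable = False

model.summary()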

Here, we:

  • Add a Dropout Layer: This layer helps prevent overfitting by randomly setting 25% of its input units to 0 at each training step.

  • Add a Dense Layer: This final dense layer has 7 units with a softmax activation function, corresponding to the 7 classes in our dataset.

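  • Freeze All but the Last 23 Layers: The loop over model.layers[:-23] freezes every layer except the last 23. The early layers keep their pre-trained ImageNet weights. The last 23 layers are fine-tuned on our data: the new Dropout and Dense layers plus the top of the MobileNet backbone.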
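The slice [:-23] is easy to misread as "the first 23 layers". The following plain-Python sketch shows which positions it actually selects. The layer names and the total of 90 layers are hypothetical stand-ins for model.layers.

```python
# Hypothetical stand-in for model.layers: 90 layers named by position
layers = [f"layer_{i}" for i in range(90)]

frozen = layers[:-23]      # everything except the last 23 layers
trainable = layers[-23:]   # the last 23 layers, including the new head

print(len(frozen), len(trainable))  # 67 23
print(trainable[0], trainable[-1])  # layer_67 layer_89
```

On the real model, comparing len(model.layers) with 23 shows how many layers stay frozen.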




4. Training the Model

We compile the model with the Adam optimizer and define two custom metrics: top-2 and top-3 categorical accuracy. The model is then trained with class weights that increase the penalty for misclassifying melanoma. Two callbacks are used: one saves the model with the best validation top-3 accuracy, and the other reduces the learning rate when that metric plateaus.
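The following simplified pure-Python simulation shows how the ReduceLROnPlateau settings in the training code (factor=0.5, patience=2, min_lr=0.00001) change the learning rate. It runs on made-up validation scores. It mirrors the callback's basic rule: count epochs without improvement, halve the rate once the count reaches the patience, then reset the count. It ignores the callback's min_delta and cooldown options, so treat it as an approximation. The function name simulate_reduce_lr is hypothetical.

```python
def simulate_reduce_lr(scores, lr=0.01, factor=0.5, patience=2, min_lr=1e-5):
    """Learning rate in effect after each epoch for a metric monitored in 'max' mode."""
    best = float('-inf')
    wait = 0
    lrs = []
    for score in scores:
        if score > best:          # improvement: remember it and reset the counter
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:  # plateau: reduce the rate, but never below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs


# Made-up val_top_3_accuracy values: the metric improves, then stalls
scores = [0.80, 0.85, 0.85, 0.84, 0.86, 0.86, 0.86, 0.86, 0.86]
print(simulate_reduce_lr(scores))
```

With these scores, the rate drops from 0.01 to 0.005 after the fourth epoch. It keeps halving every two stalled epochs, never falling below 1e-5.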

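from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy

# Custom metrics: is the true class among the model's 3 (or 2) highest-scoring classes?
def top_3_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

def top_2_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=2)

# 'learning_rate' replaces the deprecated 'lr' argument
model.compile(Adam(learning_rate=0.01), loss='categorical_crossentropy',
              metrics=[categorical_accuracy, top_2_accuracy, top_3_accuracy])

# Keys follow flow_from_directory's alphabetical class order (train_batches.class_indices)
class_weights = {
    0: 1.0,  # akiec
    1: 1.0,  # bcc
    2: 1.0,  # bkl
    3: 1.0,  # df
    4: 3.0,  # mel (Make the model more sensitive to Melanoma)
    5: 1.0,  # nv
    6: 1.0   # vasc
}

# Save the full model whenever val_top_3_accuracy reaches a new best
filepath = "model.h5"
checkpoint = ModelCheckpoint(filepath, monitor='val_top_3_accuracy', verbose=1,
                             save_best_only=True, mode='max')

# Halve the learning rate after 2 epochs without improvement, down to 1e-5
reduce_lr = ReduceLROnPlateau(monitor='val_top_3_accuracy', factor=0.5, patience=2,
                              verbose=1, mode='max', min_lr=0.00001)

callbacks_list = [checkpoint, reduce_lr]

# Batches per epoch (equivalent to ceil(num_samples / batch_size) from Part 1)
train_steps = len(train_batches)
val_steps = len(valid_batches)

# model.fit accepts generators directly; fit_generator is deprecated in TF 2.x
history = model.fit(train_batches, steps_per_epoch=train_steps,
                    class_weight=class_weights,
                    validation_data=valid_batches,
                    validation_steps=val_steps,
                    epochs=30, verbose=1,
                    callbacks=callbacks_list)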
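Here:

  • Class Weights: Only Melanoma (class index 4) gets a higher weight of 3.0. Each melanoma image's loss is therefore multiplied by 3, while all other classes keep a weight of 1.0. This makes the model more sensitive to Melanoma.

  • Callbacks: ModelCheckpoint saves the model to model.h5 whenever validation top-3 accuracy reaches a new best. ReduceLROnPlateau halves the learning rate (factor=0.5) when validation top-3 accuracy has not improved for 2 epochs (patience=2), down to a minimum of 0.00001.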
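The following sketch illustrates two points about class_weights, using NumPy and made-up probabilities. First, its keys must follow the alphabetical order that flow_from_directory uses to assign class indices. Second, Keras multiplies each sample's loss by the weight of that sample's class. The helpers make_probs and weighted_cross_entropy and the probability values are hypothetical.

```python
import numpy as np

# flow_from_directory assigns indices to class folders in alphabetical order
class_names = sorted(['nv', 'mel', 'bkl', 'bcc', 'akiec', 'vasc', 'df'])
class_indices = {name: i for i, name in enumerate(class_names)}
print(class_indices['mel'])  # 4

# Same weighting as above: 3.0 for melanoma, 1.0 for everything else
class_weights = {i: (3.0 if name == 'mel' else 1.0) for name, i in class_indices.items()}


def make_probs(true_idx: int, p_true: float = 0.4) -> np.ndarray:
    """Hypothetical softmax output: p_true on the true class, the rest spread evenly."""
    probs = np.full(7, (1.0 - p_true) / 6)
    probs[true_idx] = p_true
    return probs


def weighted_cross_entropy(probs: np.ndarray, true_idx: int) -> float:
    """Categorical cross-entropy of one sample, scaled by its class weight."""
    return float(class_weights[true_idx] * -np.log(probs[true_idx]))


mel, nv = class_indices['mel'], class_indices['nv']
mel_loss = weighted_cross_entropy(make_probs(mel), mel)
nv_loss = weighted_cross_entropy(make_probs(nv), nv)
print(round(mel_loss, 3), round(nv_loss, 3))  # 2.749 0.916
```

The same 40% confidence on the true class costs three times as much when the image is a melanoma. This pushes the optimizer to raise melanoma probabilities.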



5. Evaluating the Model

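After training, we evaluate the model on the validation set twice. The first evaluation uses the weights from the final epoch. The second uses the best checkpoint saved by ModelCheckpoint, reloaded from model.h5.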

# Evaluate the weights from the last training epoch.
# test_batches has batch_size=1, so len(test_batches) equals the number of validation images.
val_loss, val_cat_acc, val_top_2_acc, val_top_3_acc = \
    model.evaluate(test_batches, steps=len(test_batches))

print('val_loss:', val_loss)
print('val_cat_acc:', val_cat_acc)
print('val_top_2_acc:', val_top_2_acc)
print('val_top_3_acc:', val_top_3_acc)

# Reload the best checkpoint (highest val_top_3_accuracy) and evaluate again
model.load_weights('model.h5')

val_loss, val_cat_acc, val_top_2_acc, val_top_3_acc = \
    model.evaluate(test_batches, steps=len(test_batches))

print('val_loss:', val_loss)
print('val_cat_acc:', val_cat_acc)
print('val_top_2_acc:', val_top_2_acc)
print('val_top_3_acc:', val_top_3_acc)

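Categorical accuracy is the fraction of images whose highest-scoring class is the true class. Top-2 and top-3 accuracy count a prediction as correct if the true class is among the 2 or 3 highest-scoring classes. Comparing the three metrics shows whether the model's mistakes are near misses between visually similar lesions.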
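The following NumPy re-implementation shows what the top-k metrics measure, using made-up predictions for three images and four classes. It is a sketch for intuition, not a replacement for Keras's top_k_categorical_accuracy; tie-breaking may differ. The function name top_k_accuracy is hypothetical.

```python
import numpy as np


def top_k_accuracy(y_true: np.ndarray, y_prob: np.ndarray, k: int) -> float:
    """Fraction of samples whose true class is among the k highest-probability classes."""
    top_k = np.argsort(y_prob, axis=1)[:, -k:]   # indices of the k largest scores per row
    hits = [label in row for label, row in zip(y_true, top_k)]
    return float(np.mean(hits))


# Made-up probabilities for 3 images over 4 classes
y_prob = np.array([
    [0.70, 0.20, 0.05, 0.05],   # true class 0: ranked 1st
    [0.50, 0.30, 0.15, 0.05],   # true class 1: ranked 2nd
    [0.40, 0.35, 0.05, 0.20],   # true class 3: ranked 3rd
])
y_true = np.array([0, 1, 3])

for k in (1, 2, 3):
    print(k, round(top_k_accuracy(y_true, y_prob, k), 3))
# 1 0.333
# 2 0.667
# 3 1.0
```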




6. Visualizing Results

We visualize the training process by plotting training and validation loss, categorical accuracy, and top-2 and top-3 accuracy over the epochs.

import matplotlib.pyplot as plt

acc = history.history['categorical_accuracy']
val_acc = history.history['val_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
train_top2_acc = history.history['top_2_accuracy']
val_top2_acc = history.history['val_top_2_accuracy']
train_top3_acc = history.history['top_3_accuracy']
val_top3_acc = history.history['val_top_3_accuracy']
epochs = range(1, len(acc) + 1)

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.figure()

plt.plot(epochs, acc, 'bo', label='Training cat acc')
plt.plot(epochs, val_acc, 'b', label='Validation cat acc')
plt.title('Training and validation cat accuracy')
plt.legend()
plt.figure()

plt.plot(epochs, train_top2_acc, 'bo', label='Training top2 acc')
plt.plot(epochs, val_top2_acc, 'b', label='Validation top2 acc')
plt.title('Training and validation top2 accuracy')
plt.legend()
plt.figure()

plt.plot(epochs, train_top3_acc, 'bo', label='Training top3 acc')
plt.plot(epochs, val_top3_acc, 'b', label='Validation top3 acc')
plt.title('Training and validation top3 accuracy')
plt.legend()

plt.show()

These plots show how performance evolves over the epochs. If the validation curve levels off or worsens while the training curve keeps improving, the model is overfitting. If both accuracy curves stay low, it is underfitting.
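ModelCheckpoint keeps only the epoch with the highest val_top_3_accuracy. It can help to read that epoch, and the per-epoch gap between training and validation accuracy, straight from history.history. The following sketch uses a small made-up dictionary with the same keys as the real one; for the real run, substitute history.history.

```python
# Made-up values with the same keys as history.history (illustration only)
hist = {
    'categorical_accuracy':     [0.60, 0.70, 0.78, 0.85, 0.90],
    'val_categorical_accuracy': [0.62, 0.70, 0.74, 0.75, 0.73],
    'val_top_3_accuracy':       [0.88, 0.92, 0.95, 0.94, 0.94],
}

# Epoch with the highest validation top-3 accuracy (1-based), i.e. the one in model.h5
val_top3 = hist['val_top_3_accuracy']
best_epoch = max(range(len(val_top3)), key=val_top3.__getitem__) + 1
print('Best epoch:', best_epoch)  # 3

# Training minus validation accuracy per epoch; a growing gap hints at overfitting
gaps = [round(t - v, 2) for t, v in zip(hist['categorical_accuracy'],
                                        hist['val_categorical_accuracy'])]
print(gaps)  # [-0.02, 0.0, 0.04, 0.1, 0.17]
```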





7. Confusion Matrix and Classification Report

Next, we generate a confusion matrix and a classification report to evaluate the model's performance on each class.

from sklearn.metrics import confusion_matrix, classification_report
# np, plt and itertools are imported in step 1

# Rewind the non-shuffled test iterator so the predictions line up with test_batches.classes
test_batches.reset()
predictions = model.predict(test_batches, steps=len(test_batches), verbose=1)

y_true = test_batches.classes              # true class index of each image
y_pred = np.argmax(predictions, axis=1)    # most probable class for each image

cm = confusion_matrix(y_true, y_pred)

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    # Optionally convert counts to fractions of each true class (row)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Write each cell's value, in white on dark cells for readability
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

# Same alphabetical order as test_batches.class_indices
cm_plot_labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
plot_confusion_matrix(cm, cm_plot_labels, title='Confusion Matrix')
plt.show()

report = classification_report(y_true, y_pred, target_names=cm_plot_labels)
print(report)

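In the confusion matrix, rows are true classes and columns are predicted classes. Off-diagonal cells therefore show which skin conditions the model confuses with each other. The classification report lists precision, recall, and F1-score for each class.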
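The per-class recall and precision in the classification report can be read directly off the confusion matrix. Recall is each diagonal entry divided by its row sum; precision is each diagonal entry divided by its column sum. The 3×3 matrix below is made up purely for illustration; with the real model, use cm and all seven labels instead.

```python
import numpy as np

# Made-up confusion matrix: rows = true class, columns = predicted class
labels = ['bkl', 'mel', 'nv']
cm = np.array([
    [30,  5, 15],   # true bkl
    [ 4, 36, 10],   # true mel
    [ 6, 12, 82],   # true nv
])

recall = np.diag(cm) / cm.sum(axis=1)      # correct / all images of that true class
precision = np.diag(cm) / cm.sum(axis=0)   # correct / all images predicted as that class

for name, r, p in zip(labels, recall, precision):
    print(f"{name}: recall={r:.2f} precision={p:.2f}")
# bkl: recall=0.60 precision=0.75
# mel: recall=0.72 precision=0.68
# nv: recall=0.82 precision=0.77
```

For this task, the melanoma row deserves the closest look. Its recall is the model's melanoma sensitivity, which the class weight of 3.0 is meant to raise.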




8. Saving the Model

Finally, we save the trained model's architecture (as JSON) and its weights for future use.

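# Save the architecture as JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)
# Save the weights; these are the best checkpoint, reloaded from model.h5 during evaluation
model.save_weights("model1.h5")
print("Saved model to disk")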

In this blog post, we leveraged the power of MobileNet, a lightweight and efficient model, to classify skin diseases. By fine-tuning the model for our specific dataset and using advanced techniques like class weighting and callbacks, we achieved strong performance. This approach demonstrates how pre-trained models can be adapted to specialized tasks with minimal computational resources, making them highly valuable in medical imaging and other domains requiring efficient, real-time processing.


If you require any assistance with this project or Machine Learning projects, please do not hesitate to contact us. We have a team of experienced developers who specialize in Machine Learning and can provide you with the necessary support and expertise to ensure the success of your project. You can reach us through our website or by contacting us directly via email or phone.

