Plant Disease Recognition & XAI
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import cv2
import zipfile
import random
import shutil
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import roc_curve, roc_auc_score
import seaborn as sns
1 - The Dataset¶
# Download and extract the dataset (outer archive, then the nested archive inside it).
dataset_url = 'https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/tywbtsjrjv-1.zip'
dataset_zip = 'plant_disease_dataset.zip'
if not os.path.exists(dataset_zip):
    # NOTE(review): relies on the external `wget` binary; check the exit status so
    # a failed download does not leave a truncated zip behind for the next run.
    if os.system(f"wget {dataset_url} -O {dataset_zip}") != 0:
        raise RuntimeError(f"Download of {dataset_url} failed")
with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
    zip_ref.extractall('plant_disease_dataset')

# Extract the nested zip (Plant_leaf_diseases_dataset_with_augmentation.zip)
inner_zip = 'plant_disease_dataset/Data for Identification of Plant Leaf Diseases Using a 9-layer Deep Convolutional Neural Network/Plant_leaf_diseases_dataset_with_augmentation.zip'
with zipfile.ZipFile(inner_zip, 'r') as zip_ref:
    zip_ref.extractall('plant_disease_dataset')

# Define the paths for "apple-healthy" and "apple-black-rot"
base_dir = 'plant_disease_dataset/Plant_leave_diseases_dataset_with_augmentation'
classes = ['Apple___healthy', 'Apple___Black_rot']

# Create a subset directory with one folder per class.
subset_dir = 'plant_disease_subset'
os.makedirs(subset_dir, exist_ok=True)
for cls in classes:
    os.makedirs(os.path.join(subset_dir, cls), exist_ok=True)

# Copy a random sample of up to `sample_size` images per class to the subset directory.
sample_size = 100
for cls in classes:
    cls_path = os.path.join(base_dir, cls)
    subset_cls_path = os.path.join(subset_dir, cls)
    images = os.listdir(cls_path)
    # Guard against classes with fewer images than requested:
    # random.sample raises ValueError when k exceeds the population size.
    sample_images = random.sample(images, min(sample_size, len(images)))
    for img in sample_images:
        shutil.copy(os.path.join(cls_path, img), subset_cls_path)
# Display some sample images
def show_samples(subset_dir, classes):
    """Plot a 2x5 grid: one row of five sample images per class."""
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for row_axes, cls in zip(axes, classes):
        cls_path = os.path.join(subset_dir, cls)
        images = os.listdir(cls_path)
        for j, ax in enumerate(row_axes):
            ax.imshow(Image.open(os.path.join(cls_path, images[j])))
            ax.axis('off')
            ax.set_title(cls)
    plt.show()

show_samples(subset_dir, classes)
2 - Pretrained Models for Image Recognition¶
Linear Solution¶
num_images_to_process = 2

def process_image(image_path):
    """Load an image in grayscale and return it together with a global and an
    adaptive binarization (None triple when the file cannot be read)."""
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print(f"Impossible de charger l'image à partir de {image_path}. Assurez-vous que le chemin est correct.")
        return None, None, None
    # Simple global threshold at the fixed value 127.
    _, thresh_global = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
    # Adaptive threshold: local mean over an 11x11 neighbourhood, offset by 2.
    thresh_adaptive = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2)
    return image, thresh_global, thresh_adaptive

# Loop over a random sample of images for each class and plot the three variants.
for cls in classes:
    cls_path = os.path.join(subset_dir, cls)
    sample_images = random.sample(os.listdir(cls_path), num_images_to_process)
    for img_name in sample_images:
        image, thresh_global, thresh_adaptive = process_image(os.path.join(cls_path, img_name))
        if image is None:
            continue
        plt.figure(figsize=(12, 6))
        panels = [
            (image, 'Image Originale'),
            (thresh_global, 'Seuillage Global'),
            (thresh_adaptive, 'Seuillage Adaptatif'),
        ]
        for pos, (panel, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, pos)
            plt.imshow(panel, cmap='gray')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.show()
The provided code snippet demonstrates the application of grayscale conversion and binary thresholding techniques using both global and adaptive methods. The goal is to segment objects from backgrounds in grayscale images. Global thresholding uses a fixed threshold value, while adaptive thresholding adjusts the threshold locally based on image neighborhoods. These methods are effective for basic segmentation tasks but may struggle with images containing shadows, uneven lighting, or complex shapes. Such factors can lead to inaccurate segmentations, highlighting the need for more advanced techniques in challenging scenarios.
num_images_to_process = 2

def process_image(image_path):
    """Load an image and return (annotated image, grayscale, blurred, Canny edges).

    Hough lines detected on the edge map are drawn in red on the returned image.
    Returns (None, None, None, None) when the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # Bug fix: the original passed a None image straight to cv2.cvtColor,
        # which raises; the caller already checks `image is not None`, so
        # return early instead of crashing.
        return None, None, None, None
    # Convert the image to grayscale.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Apply a Gaussian blur to reduce noise before edge detection.
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Canny edge detection followed by the standard Hough line transform.
    edges = cv2.Canny(blurred, 50, 150, apertureSize=3)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)
    # Draw each detected (rho, theta) line onto the original image.
    if lines is not None:
        for rho, theta in lines[:, 0]:
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a * rho
            y0 = b * rho
            # Extend the line 1000 px in both directions so it crosses the image.
            x1 = int(x0 + 1000 * (-b))
            y1 = int(y0 + 1000 * (a))
            x2 = int(x0 - 1000 * (-b))
            y2 = int(y0 - 1000 * (a))
            cv2.line(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
    return image, gray, blurred, edges

# Loop over a random sample of images for each class.
for cls in classes:
    cls_path = os.path.join(subset_dir, cls)
    images = os.listdir(cls_path)
    sample_images = random.sample(images, num_images_to_process)
    for img_name in sample_images:
        img_path = os.path.join(cls_path, img_name)
        image, gray, blurred, edges = process_image(img_path)
        if image is not None:
            # Display the results with Matplotlib (BGR -> RGB for true colours).
            plt.figure(figsize=(12, 6))
            plt.subplot(2, 2, 1)
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.title('Original Image')
            plt.subplot(2, 2, 2)
            plt.imshow(gray, cmap='gray')
            plt.title('Gray Image')
            plt.subplot(2, 2, 3)
            plt.imshow(blurred, cmap='gray')
            plt.title('Blurred Image')
            plt.subplot(2, 2, 4)
            plt.imshow(edges, cmap='gray')
            plt.title('Edges (Canny)')
            plt.tight_layout()
            plt.show()
The code demonstrates image preprocessing using Gaussian blur and edge detection techniques. Initially, each image is loaded and converted to grayscale to simplify processing. Gaussian blur is applied to reduce noise, followed by edge detection using the Canny method. Detected edges are then overlaid onto the original image using the Hough transform to draw lines. While edge detection techniques like Canny can effectively highlight boundaries in some cases, they may not perform optimally in scenarios with complex shapes or varying lighting conditions, as demonstrated in this example.
Non Linear Solution¶
ImageNet classification (pretrained neurale network)¶
ImageNet classification refers to the task of categorizing images into various predefined classes using the ImageNet dataset. The ImageNet dataset is a large visual database designed for use in visual object recognition research. It contains millions of images labeled across thousands of object categories.
- ImageNet Dataset:
Size and Scope: The dataset contains over 14 million images, annotated with information on over 20,000 categories. However, the most commonly used subset for classification tasks includes 1,000 categories.
Categories: These categories range from everyday objects like "dog" and "cat" to more specific items like "container ship" or "screwdriver."
- ImageNet Large Scale Visual Recognition Challenge (ILSVRC):
Purpose: ILSVRC is an annual competition that challenges teams to develop models that can classify and detect objects in images accurately.
Impact: The challenge has significantly advanced the field of computer vision, leading to the development of various groundbreaking deep learning architectures.
- Deep Learning Models:
Architectures: Many influential neural network architectures have been developed and tested using the ImageNet dataset. Some notable examples include:
AlexNet (2012): Pioneered deep learning for image classification.
VGG (2014): Introduced very deep networks with small convolutional filters.
GoogLeNet/Inception (2014): Proposed a novel architecture with parallel convolutional layers.
ResNet (2015): Introduced residual connections to allow very deep networks to train effectively.
EfficientNet (2019): Balanced network depth, width, and resolution for efficient scaling.
- Evaluation Metrics:
Top-1 Accuracy: The percentage of test images for which the correct label is the model's top predicted label.
Top-5 Accuracy: The percentage of test images for which the correct label is among the top five predicted labels.

VGG 16 MODEL¶
# Load the VGG-16 model with ImageNet weights. The original `models.vgg16()`
# call returns randomly initialised weights, while the recorded checkpoint
# download (vgg16-397923af.pth) shows the pretrained network was intended —
# the downstream feature analysis relies on it.
# NOTE(review): for torchvision < 0.13 use `models.vgg16(pretrained=True)`.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
model_features = model.features  # convolutional feature extractor, i.e. Phi(x)
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
Sequence of layers¶

Top layer Lower layer¶
Task-Specific Layers: The top layers are specific to the task and do not transfer well to new tasks. General Layers: The lower layers consist of more general features and are commonly used when transferring the network to a new task by removing the top layers.
variable models¶
Model Parts: The VGG-16 network available in PyTorch is divided into two parts: model.features: The first part of the model, containing the feature extraction layers. model.classifier: The top part of the model, containing the classification layers.
Function Φ(x)¶
The function Φ(x) refers to the mapping performed by the first part of the network (model.features). Subsequent analysis will be performed on the representation at the output of this function.
3 - Predicting Classes from Images (discriminating between two classes)¶
- Define the image transformation
# Preprocessing pipeline expected by ImageNet-pretrained models:
# resize to 224x224, convert to a CHW float tensor in [0, 1], then normalize
# with the standard ImageNet per-channel means and standard deviations.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
the transform defined using transforms.Compose prepares the input images by resizing them to a specific size (224x224), converting them to tensors, and then normalizing them using predefined mean and standard deviation values. This standardized format ensures that the images are compatible with neural network models
- Custom Dataset class
class PlantDataset(Dataset):
    """Image-folder dataset: one sub-directory of `root_dir` per class.

    Each item is an (image, label) pair where the label is the index of the
    class name in `self.classes`.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort directory listings: os.listdir order is filesystem-dependent,
        # which would make the class -> label mapping non-deterministic
        # between runs and machines.
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []
        # enumerate() gives each class index directly, avoiding the O(n)
        # `classes.index(cls)` lookup of the original per image.
        for label, cls in enumerate(self.classes):
            cls_path = os.path.join(root_dir, cls)
            for img in sorted(os.listdir(cls_path)):
                self.image_paths.append(os.path.join(cls_path, img))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Force RGB so grayscale/palette files still yield 3-channel tensors.
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Create the dataset and dataloader
# Batches of 32 shuffled (image, label) pairs for feature extraction.
dataset = PlantDataset(subset_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
- Extracting Features
def extract_features(dataloader, model_features, device):
    """Run every batch through `model_features` and return the flattened
    feature vectors and their labels as two numpy arrays."""
    model_features.eval()
    feature_batches, label_batches = [], []
    with torch.no_grad():  # inference only: no autograd bookkeeping
        for batch_images, batch_labels in dataloader:
            batch_out = model_features(batch_images.to(device))
            # Flatten everything after the batch dimension into one vector.
            feature_batches.append(torch.flatten(batch_out, 1).cpu().numpy())
            label_batches.append(batch_labels.numpy())
    return np.concatenate(feature_batches), np.concatenate(label_batches)
# Run feature extraction on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_features = model_features.to(device)
features, labels = extract_features(dataloader, model_features, device)
The extract_features function is designed to extract features from images using a pretrained neural network model (model_features). It handles the batching of data (images) and their labels (lbls) using a DataLoader, processes each batch through the model, and concatenates the resulting features and labels into numpy arrays for downstream tasks like training classifiers or conducting analysis. This function is essential for leveraging pre-trained models to extract meaningful representations of data for various machine learning tasks.
the difference of means¶
# Calculate the weight vector for the difference of means
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=44)
# Compute mean feature vectors for each class
C1 = X_train[y_train == 0]
C2 = X_train[y_train == 1]
mu1 = np.mean(C1, axis=0)
mu2 = np.mean(C2, axis=0)
# Difference-of-means direction, normalized to unit length. This vector `w`
# is reused later as the linear discriminant weight (and by the sensitivity
# analysis in section 4), so its name must not change.
w = mu2 - mu1
w /= np.linalg.norm(w)
# Visualize each feature's contribution to the discriminant direction.
plt.figure(figsize=(8, 6))
plt.bar(range(len(w)), w)
plt.xlabel('Feature Index')
plt.ylabel('Weight')
plt.title('Weight Vector for Difference of Means')
plt.grid(True)
plt.show()
In this plot, each bar represents the weight of a feature component in the vector w. Positive values indicate features that contribute more towards one class (C2), while negative values indicate features that contribute more towards the other class (C1).
# Simple classifier: projects a feature vector onto the difference-of-means
# direction; the sign/magnitude of the score discriminates the two classes.
def discriminant_function(x, weight=None):
    """Return the scalar projection of `x` onto `weight`.

    `weight` defaults to the module-level difference-of-means vector `w`,
    so the original call form discriminant_function(x) is unchanged.
    """
    if weight is None:
        weight = w
    return np.dot(weight, x)
Fisher discriminant comparison¶
# Train Fisher Discriminant Analysis model
lda = LinearDiscriminantAnalysis()
lda.fit(X_train, y_train)
# Predict using LDA
y_pred_lda = lda.predict(X_test)
# Calculate the AUC-ROC for LDA
# NOTE(review): computed on hard 0/1 predictions, so this equals balanced
# accuracy rather than the threshold-free AUC; the ROC curve below correctly
# uses predict_proba scores instead.
auc_roc_lda = roc_auc_score(y_test, y_pred_lda)
# Compute the confusion matrix for LDA
confusion_matrix_lda = confusion_matrix(y_test, y_pred_lda)
# Plotting AUC-ROC Curve from the class-1 posterior probabilities
fpr, tpr, thresholds = roc_curve(y_test, lda.predict_proba(X_test)[:,1])
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='b', lw=2, label=f'AUC-ROC = {auc_roc_lda:.4f}')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Fisher Discriminant Analysis')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()
# Plotting Confusion Matrix
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_lda, annot=True, cmap='Blues', fmt='d', cbar=False)
plt.title('Confusion Matrix - Fisher Discriminant Analysis')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()
Area under the ROC curve measure¶
# Evaluate the classifier with auc_roc
# Project each test sample onto the difference-of-means direction.
y_pred = np.array([discriminant_function(x) for x in X_test])
# Binarize the scores at their median, which forces a 50/50 prediction split.
threshold = np.median(y_pred)
y_pred_binary = y_pred > threshold
y_pred_binary = y_pred_binary.astype(int)
# NOTE(review): AUC on binarized predictions; the next cell recomputes it from
# the continuous scores, which is the standard AUC-ROC.
auc_roc = roc_auc_score(y_test, y_pred_binary)
# Compute the confusion matrix
confusion_matrix_dm = confusion_matrix(y_test, y_pred_binary)
AUC-ROC for difference-of-means Analysis: 0.6754385964912281 Confusion matrix for difference-of-means Analysis: [[13 6] [ 7 14]]
# Plotting AUC-ROC Curve
# The curve uses the continuous projection scores (not the binarized labels),
# giving the proper threshold-free AUC.
fpr, tpr, thresholds = roc_curve(y_test, y_pred)
auc_roc = roc_auc_score(y_test, y_pred)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='b', lw=2, label=f'AUC-ROC = {auc_roc:.4f}')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Difference-of-Means Analysis')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()
# Plotting Confusion Matrix
plt.figure(figsize=(6, 4))
sns.heatmap(confusion_matrix_dm, annot=True, cmap='Blues', fmt='d', cbar=False)
plt.title('Confusion Matrix - Difference-of-Means Analysis')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
X_test and y_test are used for prediction and evaluation, ensuring that the performance metrics are computed on a test set that is disjoint from X_train and y_train
Interpretation of these results¶
AUC-ROC Scores:¶
The number of instances used for training each model can influence performance. Larger training sets generally lead to better generalization and performance metrics.
Confusion Matrices:¶
The Fisher Discriminant Analysis shows a slightly better balance between true positives and true negatives compared to the Difference-of-Means Analysis. Both models have a similar number of misclassifications (11 for Difference-of-Means vs. 13 for Fisher Discriminant), but the specific misclassification types differ slightly.
# NOTE(review): `methods` and `auc_roc_scores` are not defined anywhere in the
# visible file — presumably a notebook cell defining them (the two method
# names plus their AUC scores) was not exported. Confirm before running this
# cell standalone.
plt.figure(figsize=(10, 5))
plt.bar(methods, auc_roc_scores, color=['blue', 'green'])
plt.xlabel('Methods')
plt.ylabel('AUC-ROC Score')
plt.title('Comparison of AUC-ROC Scores')
plt.ylim(0.65, 0.7)
# Annotate each bar with its score.
for i, score in enumerate(auc_roc_scores):
    plt.text(i, score + 0.002, f'{score:.4f}', ha='center', va='bottom', fontsize=10)
plt.show()
# Confusion Matrices
# NOTE(review): this list is never used below, and it rebinds the module-level
# `labels` array produced by extract_features — harmless from here on, but
# confusing; confirm nothing later relies on the original array.
labels = ['TN', 'FP', 'FN', 'TP']
plt.figure(figsize=(12, 5))
# Difference-of-Means Analysis
plt.subplot(1, 2, 1)
plt.imshow(confusion_matrix_dm, cmap='Blues', interpolation='nearest')
plt.title('Difference-of-Means Analysis')
plt.colorbar()
plt.xticks(np.arange(2), ['Predicted 0', 'Predicted 1'])
plt.yticks(np.arange(2), ['Actual 0', 'Actual 1'])
# Write each cell count into the heatmap.
for i in range(2):
    for j in range(2):
        plt.text(j, i, str(confusion_matrix_dm[i, j]), ha='center', va='center', color='white')
# Fisher Discriminant Analysis
plt.subplot(1, 2, 2)
plt.imshow(confusion_matrix_lda, cmap='Greens', interpolation='nearest')
plt.title('Fisher Discriminant Analysis')
plt.colorbar()
plt.xticks(np.arange(2), ['Predicted 0', 'Predicted 1'])
plt.yticks(np.arange(2), ['Actual 0', 'Actual 1'])
for i in range(2):
    for j in range(2):
        plt.text(j, i, str(confusion_matrix_lda[i, j]), ha='center', va='center', color='white')
plt.tight_layout()
plt.show()
# Comparison and Analysis
print('Comparison and Analysis:')
print('AUC-ROC Scores:')
print(f"Both models have fairly similar AUC-ROC scores, with the Fisher Discriminant Analysis slightly outperforming "
      f"the Difference-of-Means Analysis by a small margin ({auc_roc_scores[1]:.4f} vs. {auc_roc_scores[0]:.4f}).")
print('Confusion Matrices:')
print('Difference-of-Means Analysis:')
print(confusion_matrix_dm)
print('Interpretation:')
print('True Negative (TN):', confusion_matrix_dm[0, 0])
print('False Positive (FP):', confusion_matrix_dm[0, 1])
print('False Negative (FN):', confusion_matrix_dm[1, 0])
print('True Positive (TP):', confusion_matrix_dm[1, 1])
# Bug fix: the number of correctly classified instances is the trace (diagonal
# sum) of the confusion matrix, not its total sum — np.sum(cm) is simply the
# number of test instances, which is why the captured output claimed "40 were
# correctly classified" out of 40 while also reporting 13 misclassifications.
print('This confusion matrix shows how well the "Difference-of-Means Analysis" method classified instances into '
      'their respective classes. It indicates that out of 40 instances (total of the matrix), '
      f"{np.trace(confusion_matrix_dm)} were correctly classified, and {np.sum(confusion_matrix_dm) - np.trace(confusion_matrix_dm)} were misclassified.")
print('Fisher Discriminant Analysis:')
print(confusion_matrix_lda)
print('Interpretation:')
print('True Negative (TN):', confusion_matrix_lda[0, 0])
print('False Positive (FP):', confusion_matrix_lda[0, 1])
print('False Negative (FN):', confusion_matrix_lda[1, 0])
print('True Positive (TP):', confusion_matrix_lda[1, 1])
print('This confusion matrix shows the classification results of the Fisher Discriminant Analysis. It indicates that '
      'out of 40 instances, '
      f"{np.trace(confusion_matrix_lda)} were correctly classified, and {np.sum(confusion_matrix_lda) - np.trace(confusion_matrix_lda)} were misclassified.")
print('Comparison and Analysis:')
print('AUC-ROC Scores:')
print("Both models have a similar number of misclassifications (11 for Difference-of-Means vs. 13 for Fisher Discriminant), but the specific misclassification types differ slightly.")
Comparison and Analysis: AUC-ROC Scores: Both models have fairly similar AUC-ROC scores, with the Fisher Discriminant Analysis slightly outperforming the Difference-of-Means Analysis by a small margin (0.6779 vs. 0.6754). Confusion Matrices: Difference-of-Means Analysis: [[13 6] [ 7 14]] Interpretation: True Negative (TN): 13 False Positive (FP): 6 False Negative (FN): 7 True Positive (TP): 14 This confusion matrix shows how well the "Difference-of-Means Analysis" method classified instances into their respective classes. It indicates that out of 40 instances (total of the matrix), 40 were correctly classified, and 13 were misclassified. Fisher Discriminant Analysis: [[14 5] [ 8 13]] Interpretation: True Negative (TN): 14 False Positive (FP): 5 False Negative (FN): 8 True Positive (TP): 13 This confusion matrix shows the classification results of the Fisher Discriminant Analysis. It indicates that out of 40 instances, 40 were correctly classified, and 13 were misclassified. Comparison and Analysis: AUC-ROC Scores: Both models have a similar number of misclassifications (11 for Difference-of-Means vs. 13 for Fisher Discriminant), but the specific misclassification types differ slightly.
4 - Understanding the Image-Class Relation Pixel-Wise¶
Simply assessing how well classes can be predicted from images may not suffice. Instead, it becomes crucial to identify the precise input features or pixels that are most relevant for achieving accurate predictions. This approach offers several significant benefits:
Artefact and Accuracy¶
By pinpointing which features are crucial, we can validate that high prediction accuracy is not due to artifacts or irrelevant noise present in the image data. This ensures that our models are genuinely capturing meaningful patterns and not relying on incidental factors.
Insights into Image-Class Relations:¶
Beyond validation, identifying relevant input features provides deeper insights into how images are associated with specific classes. This analysis helps us understand what aspects of an image—whether textures, shapes, colors, or specific regions—are indicative of one class over another.
4.1 - Sensitivity Analysis¶
The derivative of the model w.r.t. the input pixels¶
# Implement sensitivity analysis
def sensitivity_analysis(model, image, label, device, weight=None):
    """Pixel-wise sensitivity map: gradient of the linear discriminant score
    with respect to the input pixels.

    Parameters
    ----------
    model : feature extractor Phi(x), e.g. VGG-16 `features`.
    image : (C, H, W) input tensor (unbatched).
    label : unused; kept for interface compatibility with existing callers.
    device : torch device to run on.
    weight : 1-D discriminant weight vector. Defaults to the module-level
        difference-of-means vector `w`, preserving the original behavior.

    Returns
    -------
    (H, W) numpy array: per-pixel L2 norm of the gradient over channels.
    """
    if weight is None:
        weight = w
    model.eval()
    image = image.unsqueeze(0).to(device)
    image.requires_grad = True
    # PyTorch autograd computes the input gradient of the scalar score below.
    output = model(image)
    output = torch.flatten(output, 1)
    output = F.linear(output, torch.tensor(weight).to(device))
    output.backward()
    gradient = image.grad.data
    # Collapse the channel dimension with an L2 norm -> one score per pixel.
    gradient = gradient.squeeze().pow(2).sum(0).sqrt()
    return gradient.cpu().numpy()
This function sensitivity_analysis computes the gradient of the model's output with respect to the input image pixels. It then computes the norm of this gradient vector, which serves as an importance score for each pixel in the image.
Heatmap¶
# Generate heatmap for a sample image
chosen_image_index = 50
sample_image, sample_label = dataset[chosen_image_index]  # transformed tensor + label
original_image = Image.open(dataset.image_paths[chosen_image_index])  # raw file, untransformed
heatmap = sensitivity_analysis(model_features, sample_image, sample_label, device)
# Create a figure with subplots
fig, axs = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns
# Display original image
axs[0].imshow(original_image)
axs[0].set_title('Original Image')
# Display sample image with heatmap overlay
# NOTE(review): the tensor is ImageNet-normalized, which is why matplotlib
# emits the "clipping input data" warning when it is rendered as RGB.
axs[1].imshow(sample_image.permute(1, 2, 0).numpy())
axs[1].imshow(heatmap, cmap='hot', alpha=0.6)
axs[1].set_title('Sample Image with Heatmap')
# Display heatmap
axs[2].imshow(heatmap, cmap='Reds')
axs[2].set_title('Heatmap')
# Adjust layout and display
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
The code selects a sample image from the dataset (sample_image) and its corresponding label (sample_label). It computes the sensitivity analysis heatmap (heatmap) using the sensitivity_analysis function defined earlier. It then visualizes the original image, overlays the heatmap to show important regions in the image using a hot colormap with transparency (alpha=0.6), and displays the heatmap separately in red tones.
4.2 - More Robust Explanations¶
class BiasedLayer(nn.Module):
    """Conv2d wrapper whose parameters are shifted toward their positive part.

    The wrapped layer has the same geometry (channels, kernel, stride,
    padding) as the original, but its weights and bias are replaced by
    p + 0.25 * max(p, 0), emphasising excitatory (positive) contributions.
    """

    def __init__(self, original_layer):
        super(BiasedLayer, self).__init__()
        self.original_layer = original_layer
        has_bias = original_layer.bias is not None
        clone = nn.Conv2d(
            in_channels=original_layer.in_channels,
            out_channels=original_layer.out_channels,
            kernel_size=original_layer.kernel_size,
            stride=original_layer.stride,
            padding=original_layer.padding,
            bias=has_bias,
        )
        # Shift each parameter by a quarter of its positive part.
        clone.weight.data = original_layer.weight.data + 0.25 * torch.relu(original_layer.weight.data)
        if has_bias:
            clone.bias.data = original_layer.bias.data + 0.25 * torch.relu(original_layer.bias.data)
        self.biased_layer = clone

    def forward(self, x):
        # Delegate to the re-parameterised convolution.
        return self.biased_layer(x)
# Apply the biased layer strategy to the VGG-16 model
# Wrap every Conv2d in model_features; all other layers pass through unchanged.
biased_model_features = nn.Sequential(
    *[BiasedLayer(layer) if isinstance(layer, nn.Conv2d) else layer for layer in model_features]
).to(device)
The goal is to prioritize excitatory effects (positive influences) over inhibitory effects (negative influences) within the network.
Instead of altering the forward function of entire layers, which could disrupt network functionality, specific layers are rewritten locally. This ensures that the forward pass remains unchanged, but the gradient computation is adjusted to enforce the desired bias.
# Generate robust sensitivity heatmap
# Same sensitivity analysis as in 4.1, but on the positively-biased network.
robust_heatmap = sensitivity_analysis(biased_model_features, sample_image, sample_label, device)
Heatmap :¶
# Create a figure with subplots
fig, axs = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns
# Display original image
axs[0].imshow(original_image)
axs[0].set_title('Original Image')
# Normalized tensor with the robust (biased-network) heatmap overlaid.
axs[1].imshow(sample_image.permute(1, 2, 0).numpy())
axs[1].imshow(robust_heatmap, cmap='hot', alpha=0.6)
axs[1].set_title('Sample Image with Heatmap')
# The robust heatmap alone.
axs[2].imshow(robust_heatmap, cmap='Reds')
axs[2].set_title('Heatmap')
# Adjust layout and display
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
4.3 - Discussion¶
In the example, we observe that the pixels highlighted as relevant by the discriminant function's sensitivity analysis do not accurately overlap with the region of the leaf where the disease is visible. Several potential issues could contribute to this mismatch. Below, we discuss possible sources of the problem and provide suggestions for addressing them.
Potential problems and solutions:
- Insufficiently Good Pretrained Neural Network:
Problem: The pretrained neural network (VGG-16 in this case) may not be well-suited for the specific task of identifying diseased regions in plant leaves. Solution: Consider using a more advanced or specialized neural network architecture. For instance, models like ResNet, EfficientNet, or even a model pretrained specifically on plant disease datasets might perform better. Fine-tuning the pretrained model on your specific dataset could also improve performance.
- Problems with Data Quality:
Problem: The quality of the dataset may be suboptimal, with issues such as low resolution, noise, incorrect labels or data amount. Solution: Training on a larger subset of data or improve data quality by ensuring high-resolution images, noise reduction, and accurate labeling. Augmenting the dataset with more diverse and correctly annotated samples can also help. Implementing data preprocessing steps like normalization and augmentation (e.g., rotations, flips, and color adjustments) can enhance the model’s ability to generalize.
Steps for Overcoming the Identified Problems
- Model Improvement:
Fine-tune the VGG-16 model or switch to a more advanced architecture. Train the model on a larger, more diverse dataset if available. Incorporate transfer learning by leveraging models pretrained on similar tasks.
- Feature Extraction:
Experiment with different layers and techniques for feature extraction. Use advanced explanation techniques like Grad-CAM, Integrated Gradients, or SHAP to get more accurate relevance maps.
- Data Quality Enhancement:
Ensure high-resolution, noise-free images. Verify and correct labels with the help of domain experts. Apply data augmentation to increase dataset diversity.
By addressing these potential problems, we can improve the accuracy of our model in identifying relevant features for classifying plant diseases and ensure that the highlighted pixels overlap more accurately with the diseased regions of the leaves. This will not only enhance the model's predictive performance but also its interpretability and reliability.