Skip to content
Snippets Groups Projects
Commit f52e690e authored by Leonardo de Lima Gaspar's avatar Leonardo de Lima Gaspar
Browse files

Improved confusion matrix generation, so it uses samples that are roughly...

Improved confusion matrix generation, so it uses samples that are roughly uniformly distributed between the 2 classes, across the entire validation set. The matrix agrees with the reported accuracy, suggesting something is wrong with my implementation used with OpenCV in c++ (as the acc. is lower).
parent 48db878b
No related branches found
No related tags found
No related merge requests found
Detection_ID,Date,Time of day (hour:min),Source_file,Time in (min:sek),Time out (min:sek),Art
0,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:09,00:18,
1,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:40,00:43,
2,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,01:28,02:30,
3,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,04:20,21:05,
4,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:42,21:57,
5,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:59,22:02,
0,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:09,00:18,
1,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:40,00:43,
2,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,01:28,02:30,
3,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,04:20,21:05,
4,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:42,21:57,
5,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:59,22:02,
0,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:09,00:18,
1,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,00:40,00:43,
2,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,01:28,02:30,
3,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,04:20,21:05,
4,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:42,21:57,
5,2021-07-11,17:03:37,Myggbukta-[2021-07-09_13-48-44]-105,21:59,22:02,
from tkinter import E
import matplotlib.pyplot as plt
import json
def plot_graphs(history, string):
plt.plot(history[string])
plt.plot(history["val_"+string])
plt.xlabel("Epochs")
plt.ylabel(string)
plt.legend([string,"val_"+string])
plt.show()
if __name__ == "__main__":
history = json.load(open("src/python/utilities/trainedModels/3/trainingHistory.json"))
for key in history:
print(key)
plot_graphs(history,'accuracy')
plot_graphs(history,'loss')
\ No newline at end of file
import matplotlib.pyplot as plt
import seaborn as sn
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras import Model
from tensorflow.keras import models
def plot_graphs(history, string):
plt.plot(history[string])
plt.plot(history["val_"+string])
plt.xlabel("Epochs")
plt.ylabel(string)
plt.legend([string,"val_"+string])
plt.show()
def plot_confusionMatrix():
resizeWidth = 288
resizeHeight = 162
datasetLimit = 1500
model = models.load_model("src\\python\\utilities\\trainedModels\\3\\model")
dataset = image_dataset_from_directory(
directory="dataset/validate",
labels="inferred",
label_mode="binary",
class_names=["NoFish", "fish"],
image_size=(resizeWidth, resizeHeight),
batch_size=1,
shuffle=False,
interpolation="nearest"
)
labelsNumpy = []
data = []
iterator = dataset.as_numpy_iterator()
step = int(dataset.cardinality().numpy()/datasetLimit)
for cout in range(datasetLimit):
img, label = iterator.next()
labelsNumpy.append(label[0])
data.append(img)
for i in range(step-1):
iterator.next()
data = np.array(data)
data = np.squeeze(data)
labelsNumpy = np.array(labelsNumpy)
labelsNumpy.flatten()
preds = model.predict(data, batch_size=1)
preds.flatten()
# 2-by-2 matrix for binary classification.
matrix = np.zeros((2,2), dtype=np.int32)
for i in range(len(labelsNumpy)):
y = int(round(preds[i][0], 1))
x = int(round(labelsNumpy[i][0], 1))
matrix[y][x] += 1
classes = ["Not fish", "Fish"]
sn.set(font_scale=2)
ax = sn.heatmap(matrix, annot=True, xticklabels=classes, yticklabels=classes, cbar=True, fmt="g")
ax.set(title="Confusion matrix", xlabel="Ground truth", ylabel="Predicted")
plt.tight_layout()
plt.show()
if __name__ == "__main__":
history = json.load(open("src/python/utilities/trainedModels/3/trainingHistory.json"))
plot_graphs(history,'accuracy')
plot_graphs(history,'loss')
plot_confusionMatrix()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment