Select Git revision
fileserver.sh
-
Marius Raes authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
sex_identification.py 2.31 KiB
from imageai.Detection.Custom import DetectionModelTrainer, CustomObjectDetection
import os
import tensorflow as tf
def identify(image):
    """Return the class label of the most confident detection in *image*.

    Loads the custom YOLOv3 detection model from the ``salamander-api``
    tree, runs it on *image*, and lets ImageAI write the annotated picture
    to ``img_analyze/detected_salamander.png`` inside that tree.

    The ``salamander-api`` root is resolved relative to the current working
    directory, so the result depends on where the process was started.

    Args:
        image: Forwarded unchanged as ``input_image`` to
            ``CustomObjectDetection.detectObjectsFromImage`` (presumably a
            path to an image file -- confirm against callers).

    Returns:
        The ``name`` entry of the detection with the highest
        ``percentage_probability``.

    Raises:
        ValueError: If no object is detected with at least 40% probability.
    """
    # Let TensorFlow allocate GPU memory on demand rather than grabbing it
    # all up front. The session is bound to a local so it stays alive for
    # the duration of the call; it is not used directly.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    _session = tf.compat.v1.Session(config=config)

    root_path = "../../../salamander-api"
    os_directory_path = os.path.abspath(root_path)

    image_detector = CustomObjectDetection()
    image_detector.setModelTypeAsYOLOv3()
    image_detector.setModelPath(
        os.path.join(
            os_directory_path,
            "algorithm/train_src/imageai_model/models/detection_model-ex-012--loss-0018.483.h5",
        )
    )
    image_detector.setJsonPath(
        os.path.join(
            os_directory_path,
            "algorithm/train_src/imageai_model/json/detection_config.json",
        )
    )
    image_detector.loadModel()

    predictions = image_detector.detectObjectsFromImage(
        input_image=image,
        output_image_path=os.path.join(
            os_directory_path, "img_analyze/detected_salamander.png"
        ),
        minimum_percentage_probability=40,
    )

    # max() on an empty sequence raises an unhelpful ValueError; raise the
    # same exception type with a message that explains what went wrong.
    if not predictions:
        raise ValueError(
            "no object detected in the image with at least 40% probability"
        )

    best = max(predictions, key=lambda p: p["percentage_probability"])
    return best["name"]
def train():
    """Train a YOLOv3 detection model for the "male" and "female" classes.

    Configures an ImageAI ``DetectionModelTrainer`` with a batch size of 8
    and ``num_experiments=200``, points it at the ``imageai_model`` data
    directory, and starts training. The data directory is resolved
    relative to the current working directory.
    """
    # Let TensorFlow allocate GPU memory on demand rather than all at once.
    gpu_config = tf.compat.v1.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    tf_session = tf.compat.v1.Session(config=gpu_config)

    # NOTE(review): this resolves to <root>/img_analyze/algorithm/..., while
    # identify() loads its model from <root>/algorithm/... -- confirm that
    # the extra "img_analyze" component is intended.
    base_dir = os.path.abspath("../../../salamander-api/img_analyze")
    dataset_dir = os.path.join(base_dir, "algorithm/train_src/imageai_model")

    model_trainer = DetectionModelTrainer()
    model_trainer.setModelTypeAsYOLOv3()
    model_trainer.setDataDirectory(data_directory=dataset_dir)
    # When detecting more object classes, list every class name in
    # object_names_array, e.g. ["object1", "object2", ..., "objectz"].
    model_trainer.setTrainConfig(
        object_names_array=["male", "female"],
        batch_size=8,
        num_experiments=200,
    )
    model_trainer.trainModel()