from imageai.Detection.Custom import DetectionModelTrainer, CustomObjectDetection
import os
import tensorflow as tf


def identify(image, minimum_percentage_probability=40):
    """Run the trained YOLOv3 salamander detector on *image* and return the top label.

    The model weights and detection config are loaded from fixed locations under
    ``../../../salamander-api`` (resolved against the current working directory)
    on every call. An annotated copy of the image is written to
    ``<salamander-api>/img_analyze/detected_salamander.png``.

    Args:
        image: Passed straight through to ImageAI as ``input_image``. Presumably
            a file path, since ImageAI's default ``input_type`` is used; confirm
            against callers.
        minimum_percentage_probability: Minimum detection confidence (in
            percent) passed to ImageAI. Defaults to 40, the previously
            hard-coded value.

    Returns:
        The ``name`` of the detection with the highest
        ``percentage_probability``.

    Raises:
        ValueError: If no object is detected at or above the threshold.
    """
    # NOTE(review): `session` is never used directly; it appears to exist only so
    # TensorFlow is initialised with allow_growth. It stays referenced until the
    # function returns. Confirm it is still needed.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config)  # noqa: F841

    root_path = os.path.abspath("../../../salamander-api")

    detector = CustomObjectDetection()
    detector.setModelTypeAsYOLOv3()
    detector.setModelPath(os.path.join(
        root_path,
        "algorithm/train_src/imageai_model/models/detection_model-ex-012--loss-0018.483.h5"))
    detector.setJsonPath(os.path.join(
        root_path, "algorithm/train_src/imageai_model/json/detection_config.json"))
    detector.loadModel()

    predictions = detector.detectObjectsFromImage(
        input_image=image,
        output_image_path=os.path.join(root_path, "img_analyze/detected_salamander.png"),
        minimum_percentage_probability=minimum_percentage_probability,
    )

    # max() on an empty list fails with an opaque "empty sequence" message.
    # Report the actual cause instead, keeping the same exception type.
    if not predictions:
        raise ValueError(
            "no object detected with percentage_probability >= {}".format(
                minimum_percentage_probability))

    best = max(predictions, key=lambda p: p['percentage_probability'])
    return best['name']


def train(data_directory=None, object_names=("male", "female"), batch_size=8,
          num_experiments=200):
    """Train a YOLOv3 object detector with ImageAI's ``DetectionModelTrainer``.

    Every argument defaults to the value that was previously hard-coded, so
    calling ``train()`` with no arguments behaves exactly as before.

    Args:
        data_directory: Dataset directory handed to ``setDataDirectory``.
            Defaults to ``../../../salamander-api/img_analyze`` (resolved
            against the current working directory) joined with
            ``algorithm/train_src/imageai_model``.
            NOTE(review): ``identify`` loads its model and JSON config from
            ``<salamander-api>/algorithm/train_src/imageai_model`` (no
            ``img_analyze`` component). One of the two locations is probably
            wrong; confirm which.
        object_names: Class labels, passed to ImageAI as ``object_names_array``.
        batch_size: Training batch size.
        num_experiments: Passed to ImageAI as ``num_experiments``.
    """
    # NOTE(review): `session` is never used directly; it appears to exist only so
    # TensorFlow is initialised with allow_growth. Confirm it is still needed.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config)  # noqa: F841

    if data_directory is None:
        data_directory = os.path.join(
            os.path.abspath("../../../salamander-api/img_analyze"),
            "algorithm/train_src/imageai_model")

    trainer = DetectionModelTrainer()
    trainer.setModelTypeAsYOLOv3()
    trainer.setDataDirectory(data_directory=data_directory)
    # ImageAI received a list here originally; copy so a tuple default works too.
    trainer.setTrainConfig(object_names_array=list(object_names),
                           batch_size=batch_size,
                           num_experiments=num_experiments)
    trainer.trainModel()
