Skip to content
Snippets Groups Projects
Select Git revision
  • main
  • master
2 results

Handler.php

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    predict_salamander_abdomen.py 3.63 KiB
    # This file contains modified code from DeepLabCut's source code
    
    import os.path
    import time
    import numpy as np
    from pathlib import Path
    from skimage.util import img_as_ubyte
    
    
    def run_prediction_dlc(config: str, image: np.ndarray, shuffle: int = 1,
                           trainingsetindex: int = 0, gputouse: int = None):
        """Predict DeepLabCut pose for a single image with a trained model.

        Reads the project config, locates the model folder for the requested
        shuffle / training fraction, restores the configured snapshot (the
        newest one when ``snapshotindex`` is ``'all'``), and runs one forward
        pass on ``image``.

        Args:
            config: Path to the DeepLabCut project config file.
            image: Three-dimensional image array (rows, columns, channels);
                presumably RGB as DeepLabCut expects — TODO confirm.
            shuffle: Shuffle number of the trained model.
            trainingsetindex: Index into ``cfg['TrainingFraction']``.
            gputouse: If not ``None``, exported as ``CUDA_VISIBLE_DEVICES``
                before the model is loaded.

        Returns:
            ``(pose, nframes, nx, ny)``: ``pose`` is a float array of length
            ``num_outputs * 3 * len(all_joints_names)`` (presumably x, y,
            likelihood per joint — verify), ``nframes`` is always 1, and
            ``nx`` / ``ny`` are the image width / height.

        Raises:
            ValueError: If ``image`` is not three-dimensional.
            FileNotFoundError: If the model's test config is missing, or the
                ``train`` folder is missing or holds no snapshots.
        """
        from deeplabcutcore.pose_estimation_tensorflow.nnet import predict
        from deeplabcutcore.pose_estimation_tensorflow.config import load_config
        from tensorflow.python.framework.ops import reset_default_graph
        from deeplabcutcore.utils import auxiliaryfunctions

        # Validate the input before paying for the (expensive) model restore.
        if np.ndim(image) != 3:
            raise ValueError(
                "Expected a 3-dimensional image (rows, columns, channels), "
                "got shape %s." % (np.shape(image),))
        ny, nx = np.shape(image)[:2]

        # Was potentially set during training.
        os.environ.pop('TF_CUDNN_USE_AUTOTUNE', None)

        if gputouse is not None:
            # GPU selection must happen before the session is created below.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(gputouse)

        reset_default_graph()

        cfg = auxiliaryfunctions.read_config(config)
        train_fraction = cfg['TrainingFraction'][trainingsetindex]
        model_folder = os.path.join(
            cfg["project_path"],
            str(auxiliaryfunctions.GetModelFolder(train_fraction, shuffle, cfg)))
        path_test_config = Path(model_folder) / 'test' / 'pose_cfg.yaml'
        try:
            dlc_cfg = load_config(str(path_test_config))
        except FileNotFoundError as err:
            raise FileNotFoundError(
                "It seems the model for shuffle %s and trainFraction %s does not exist."
                % (shuffle, train_fraction)) from err

        # Snapshot names are the stems of files containing "index" in the
        # train folder; the iteration count follows the first '-'.
        train_dir = os.path.join(model_folder, 'train')
        not_trained_msg = (
            "Snapshots not found! It seems the dataset for shuffle %s has not "
            "been trained/does not exist.\nPlease train it before using it to "
            "run predictions.\nUse the function 'train_network' to train the "
            "network for shuffle %s." % (shuffle, shuffle))
        try:
            snapshots = [fn.split('.')[0] for fn in os.listdir(train_dir) if "index" in fn]
        except FileNotFoundError as err:
            raise FileNotFoundError(not_trained_msg) from err
        if not snapshots:
            # Without this check an empty folder surfaces as an IndexError below.
            raise FileNotFoundError(not_trained_msg)
        snapshots.sort(key=lambda name: int(name.split('-')[1]))

        # 'all' would mean evaluating every snapshot, which is too costly for
        # prediction; fall back to the newest snapshot instead.
        snapshot_index = -1 if cfg['snapshotindex'] == 'all' else cfg['snapshotindex']

        ##################################################
        # Load and setup CNN part detector
        ##################################################

        dlc_cfg['init_weights'] = os.path.join(train_dir, snapshots[snapshot_index])
        # A single image is predicted, so one item per batch.
        dlc_cfg['batch_size'] = 1

        sess, inputs, outputs = predict.setup_pose_prediction(dlc_cfg)
        try:
            # Number of outputs per joint; predict.getpose receives dlc_cfg.
            dlc_cfg['num_outputs'] = cfg.get('num_outputs', 1)

            nframes = 1
            predicted_data = np.zeros(
                (nframes, dlc_cfg['num_outputs'] * 3 * len(dlc_cfg['all_joints_names'])))

            frame = img_as_ubyte(image)
            pose = predict.getpose(frame, dlc_cfg, sess, inputs, outputs)
            predicted_data[0, :] = pose.flatten()
        finally:
            # The session (presumably a TF1 tf.Session) is local to this call;
            # close it so repeated calls do not accumulate device memory.
            sess.close()

        return predicted_data[0], nframes, nx, ny