from cv2 import cv2
import numpy as np
import math
import matplotlib.pyplot as plt
import constants
from scipy.interpolate import interp1d
import time

# Dimensions (columns, rows) of the straightened output image, read from the
# project-wide constants module.
width, height = constants.width, constants.height

# Rows per column of the output; scales the sampled belly width relative to
# the salamander length.
aspect_ratio = height / width

def straighten(img, points):
	"""Return a straightened copy of the belly pattern in ``img``.

	A smooth curve with ``width`` samples is fitted through ``points``. A
	sampling grid that follows that curve is then built, and ``img`` is
	resampled with ``cv2.remap`` (bilinear interpolation). Each output column
	is centred on one point of the fitted curve.

	Args:
		img: Source image passed to ``cv2.remap``.
		points: Sequence of 2-D points along the belly curve, presumably (x, y)
			pixel coordinates in ``img`` -- TODO confirm with caller.

	Returns:
		The remapped image. Its size matches the sampling map: ``height`` rows
		by ``width`` columns.
	"""
	# perf_counter is monotonic, so the reported durations cannot be skewed
	# by wall-clock adjustments (unlike time.time).
	start = time.perf_counter()
	curve = get_smooth_curve(points, width)
	print("Bicubic interpolation of spine took " + str(time.perf_counter() - start) + "s")

	start = time.perf_counter()
	sample_map = generate_map_from_bellycurve(curve)
	# cv2.remap takes the two coordinate channels as separate map arrays.
	result = cv2.remap(img, sample_map[:, :, 0], sample_map[:, :, 1], cv2.INTER_LINEAR)
	print("Straightening took " + str(time.perf_counter() - start) + "s")

	return result

def generate_map_from_bellycurve(curve, aspect=None, num_rows=None):
	"""Build a sampling map for ``cv2.remap`` that unrolls the belly along ``curve``.

	At every curve point the unit normal is computed, and ``num_rows`` samples
	are placed along it, evenly spaced from one side of the belly to the
	other. Column ``j`` of the map is therefore the normal segment centred on
	``curve[j]``.

	Args:
		curve: Array-like of shape (N, 2) with N >= 2 points along the belly,
			presumably (x, y) image coordinates -- TODO confirm with caller.
		aspect: Ratio of sampled belly width to salamander length. Defaults to
			the module-level ``aspect_ratio``.
		num_rows: Number of samples across the belly (rows of the map).
			Defaults to the module-level ``height``.

	Returns:
		float32 array of shape (num_rows, N, 2). Channel 0 and channel 1 hold
		the first and second coordinate of each sample point.
	"""
	if aspect is None:
		aspect = aspect_ratio
	if num_rows is None:
		num_rows = height

	curve = np.asarray(curve, dtype=float)

	# NOTE(review): "length" is the straight-line distance between the two end
	# points, not the arc length of the curve, so a strongly bent curve gets a
	# narrower sampling band. Confirm this is intended.
	salamander_length = np.linalg.norm(curve[-1] - curve[0])
	salamander_width = salamander_length * aspect

	# Unit tangent at every curve point. np.gradient uses central differences
	# inside and one-sided differences at the ends. A zero-length tangent
	# yields NaN entries in the map.
	tangent = np.gradient(curve, axis=0)
	tangent /= np.linalg.norm(tangent, axis=1, keepdims=True)

	# Rotate each tangent by 90 degrees, (tx, ty) -> (-ty, tx), to get the unit
	# normal along which the belly is sampled.
	normal = np.column_stack((-tangent[:, 1], tangent[:, 0]))

	# Offsets from -1 to +1 across the belly. The (num_rows, 1, 1) shape
	# broadcasts against the (N, 2) normals and curve points.
	s = np.linspace(-1, 1, num_rows).reshape((num_rows, 1, 1))

	sample_map = s * normal * (salamander_width / 2) + curve

	return sample_map.astype('float32')

def get_smooth_curve(points, num_points):
	"""Resample a polyline as ``num_points`` points on a smooth spline through it.

	Each input point is parameterised by normalised cumulative chord length
	(0 at the first point, 1 at the last). An interpolating spline is fitted
	through the points and evaluated at ``num_points`` evenly spaced parameter
	values. Because the spline passes through every input point, the result
	starts at the first input point and ends at the last.

	A cubic spline is used when at least four distinct points are available.
	With three or two points the order drops to quadratic or linear, so short
	inputs still work. Consecutive duplicate points are dropped first, because
	they would produce repeated parameter values that the spline fit rejects.

	Args:
		points: Array-like of shape (M, D), presumably 2-D (x, y) coordinates.
		num_points: Number of points to return.

	Returns:
		float array of shape (num_points, D).

	Raises:
		ValueError: If ``points`` is not 2-D or has fewer than two distinct
			points.
	"""
	points = np.asarray(points, dtype=float)
	if points.ndim != 2 or len(points) < 2:
		raise ValueError("points must be an (M, D) array with M >= 2")

	# Length of each segment between consecutive points.
	distance = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))

	# Drop points that repeat their predecessor. They add a zero-length
	# segment, and hence a duplicate parameter value.
	keep = np.concatenate(([True], distance > 0))
	points = points[keep]
	distance = distance[distance > 0]
	if len(points) < 2:
		raise ValueError("points must contain at least two distinct points")

	# Cumulative chord length, starting at 0 and normalised to end at 1.
	accumulated_distance = np.concatenate(([0.0], np.cumsum(distance)))
	accumulated_distance /= accumulated_distance[-1]

	# Highest spline order the number of points supports, capped at cubic.
	kind = ('linear', 'quadratic', 'cubic')[min(len(points), 4) - 2]

	f = interp1d(accumulated_distance, points, kind=kind, axis=0)

	return f(np.linspace(0, 1, num_points))
