#!/usr/bin/python
import wx
import numpy as np
import nmrglue as ng
import copy
import math
import sys
import os
import time
from pprint import pprint
import argparse
import matplotlib
from matplotlib.figure import Figure
matplotlib.use('WXAgg')
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as FigureCanvas
from matplotlib.backends.backend_wxagg import NavigationToolbar2WxAgg as NavigationToolbar

from scipy.ndimage import label, find_objects
from scipy.spatial import distance_matrix, distance
from scipy.optimize import leastsq, differential_evolution
import scipy.stats

### Parameters ###

parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', help='path to directory')
parser.add_argument('-th', '--threshold', help='Threshold for plotting')
parser.add_argument('-tr', '--temp_range', help='Temperature range [K]')
parser.add_argument('-ti', '--temp_init', help='Initial Temperature [K]')
parser.add_argument('-pl', '--peak_list', help='Peak list in initial temperature')


def _arg_or_prompt(value, prompt):
    """Return a command-line value, or ask for it interactively when it was not given."""
    if value is not None:
        return value
    return raw_input(prompt)


args = parser.parse_args()
# Prompts are issued in the same order as before: path, threshold, T_init, T_range, peak list.
os.chdir(_arg_or_prompt(args.path, "Enter path to your directory:\n"))
thresh = float(_arg_or_prompt(args.threshold, "Enter threshold value:\n"))
temp_initial = float(_arg_or_prompt(args.temp_init, "Enter initial temperature (in Kelvins):\n "))
temp_diff = float(_arg_or_prompt(args.temp_range, "Enter temperature range (in Kelvins):\n "))
pl_name = _arg_or_prompt(args.peak_list, "Enter peak list name:\n")

# Collect spectrum numbers (from *.ft3 names) and numbers of figures saved earlier.
file_list = [f for f in os.listdir('.') if os.path.isfile(os.path.join('.', f))]
spectrum_numbers = []
figure_numbers = []
for filename in file_list:
    if ".ft3" in filename:
        spectrum_numbers.append(int(filename.replace(".ft3", "").replace("tempscan", "").replace(".ucsf", "")))
    # Only "figure<N>.eps" files (as written by SaveFigure) are counted; any other file
    # that merely contains "figure" in its name used to crash the int() conversion.
    if filename.startswith("figure") and filename.endswith(".eps"):
        number = filename[len("figure"):-len(".eps")]
        if number.isdigit():
            figure_numbers.append(int(number))
filenames_ft3 = np.array(spectrum_numbers, dtype=int)
filenames_eps = np.array(figure_numbers, dtype=int)

# Next free figure number (1 when no figure has been saved yet).
last_figurename = filenames_eps.max() + 1 if filenames_eps.size else 1

if filenames_ft3.size == 0:
    sys.exit('No .ft3 spectra found in ' + os.getcwd())

# Spectra are assumed to be numbered 0 .. max_i-1 (tempscan<i>.ft3).
max_i = np.max(filenames_ft3) + 1
print('\n%d spectra in the series\n' % max_i)
first_spectrum = np.min(filenames_ft3)


print('Threshold for contour maps and peak-picking set to: %e' % thresh)

# One temperature per spectrum, evenly spaced over [T_init, T_init + T_range].
temp_range = np.linspace(temp_initial, temp_initial+temp_diff, num=max_i, endpoint=True)

### Peak picking and tracking ###

start = time.time()  # wall-clock reference for the peak-tracking timings
dic, data = ng.fileio.pipe.read('tempscan0.ft3')
# Reverse the first three axes (equivalent to three nested np.flip calls).
data = data[::-1, ::-1, ::-1]

def detect_peak_3d(data, dx1, dx2, dx3, s):  # s - neighbour test size
    """Find local maxima of |data| inside a 3-D bounding box.

    A point is reported when its |value| is >= every |value| in the window of
    half-widths ``s`` around it, i.e. (2*s[0]+1) x (2*s[1]+1) x (2*s[2]+1),
    clipped to the array bounds.

    Parameters
    ----------
    data : 3-D array
    dx1, dx2, dx3 : slices with integer start/stop selecting the box to scan
        (find_all_peaks passes the bounding boxes from scipy.ndimage.find_objects)
    s : sequence of three half-widths of the neighbour-test window

    Returns
    -------
    The peak indices, in the layouts callers rely on:
    ``np.array([[], []])`` when nothing is found, a 1-D ``[x1, x2, x3]`` array
    for exactly one peak, and an ``(n, 3)`` integer array otherwise.
    """
    magnitude = np.abs(data)
    found = []
    for x1 in range(dx1.start, dx1.stop):
        # Clamp the lower window bounds at 0: a negative slice start wraps
        # around to the far end of the axis and yields an empty window, which
        # used to drop every peak lying on a low-index edge.
        lo1 = max(x1 - s[0], 0)
        for x2 in range(dx2.start, dx2.stop):
            lo2 = max(x2 - s[1], 0)
            for x3 in range(dx3.start, dx3.stop):
                lo3 = max(x3 - s[2], 0)
                window = magnitude[lo1:x1 + s[0] + 1, lo2:x2 + s[1] + 1, lo3:x3 + s[2] + 1]
                if window.size and magnitude[x1, x2, x3] >= window.max():
                    found.append([x1, x2, x3])
    if not found:
        return np.array([[], []])
    if len(found) == 1:
        return np.array(found[0])
    return np.array(found)


def find_all_peaks(data, thresh, s):
    """Pick peaks in a 3-D spectrum.

    Points with |data| >= thresh are grouped into connected regions
    (scipy.ndimage.label). detect_peak_3d then looks for local maxima of |data|
    inside each region's bounding box, using the neighbour-test half-widths ``s``.

    Returns
    -------
    p : (n, 3) integer array of peak indices. It is always 2-D, even for 0 or 1
        peaks, so callers can use ``p[:, axis]`` and ``np.shape(p)[0]``.
    h : (n,) float array with the (signed) data value at each peak.
    """
    data_denoise = np.where(np.abs(data) >= thresh, data, 0)  # keep only data above threshold

    labeled, num_labeled = label(data_denoise)  # num_labeled - number of separate regions above threshold
    print('labeled regions:  %d' % num_labeled)

    rows = []
    for dx1, dx2, dx3 in find_objects(labeled):
        local_max = detect_peak_3d(data, dx1, dx2, dx3, s)  # s - size area: "neighbours test"
        if np.size(local_max) != 0:
            # detect_peak_3d returns a 1-D row for a single peak; make every chunk 2-D.
            rows.append(np.atleast_2d(local_max))

    if rows:
        p = np.vstack(rows).astype(int)
    else:
        p = np.empty((0, 3), dtype=int)

    # The previous single-peak branch indexed the 3-D array with four indices
    # (IndexError); vectorised indexing covers every case uniformly.
    h = data[p[:, 0], p[:, 1], p[:, 2]].astype(float)
    return p, h


def correspond(list1, list2, h_list2):
    """Reorder the peaks of a new frame so they line up with the previous frame.

    Row i of the result is the new peak closest (Euclidean distance) to old
    peak i. When several old peaks pick the same new peak, only the nearest
    of them keeps it. The others get the sentinel position -9223372036854775808
    (stored as float) and height 0. New peaks not picked by any old peak are
    appended after the matched rows, so they receive fresh indices.

    Parameters
    ----------
    list1 : old peak positions, (m, 3) array or a single 1-D position; may be empty
    list2 : new peak positions, (n, 3) array or a single 1-D position; may be empty
    h_list2 : heights of the new peaks, (n,) array or a scalar for a single peak

    Returns
    -------
    (positions, heights) for the new frame. If list1 is empty, list2 and h_list2
    are returned unchanged. Otherwise positions is a float (rows, 3) array and
    heights a float (rows,) array.
    """
    missing = -9223372036854775808  # "peak not found" marker used by the tracker

    if np.size(list1) == 0:
        # No previous peaks to follow: keep the new frame as it is.
        return list2, h_list2
    old = np.atleast_2d(list1)

    if np.size(list2) == 0:
        # Every old peak disappeared. A float array keeps ppm data consistent
        # (the old int fill also referenced an undefined name `shape`).
        return np.full(np.shape(old), missing, dtype=float), np.zeros(len(old))

    # A single 1-D peak on either side is just a one-row table; this replaces
    # the special-case branches (one of which used `dist` before assigning it).
    new = np.atleast_2d(list2)
    heights = np.atleast_1d(h_list2)

    dist = distance_matrix(old, new)
    closest = np.argmin(dist, axis=1)  # closest new peak to each old peak

    # Resolve conflicts: a new peak claimed by several old peaks goes to the nearest one.
    taken, counts = np.unique(closest, return_counts=True)
    suppressed = []
    for j in taken[counts > 1]:
        candidates = np.where(closest == j)[0]
        keeper = candidates[np.argmin(dist[candidates, j])]
        suppressed.extend(candidates[candidates != keeper])
    suppressed = np.array(suppressed, dtype=int)
    unmatched = np.setdiff1d(np.arange(len(new)), taken)  # new peaks no old peak chose

    positions = np.vstack((new[closest], new[unmatched])).astype(float)
    positions[suppressed, :] = missing
    heights_out = np.concatenate((heights[closest], heights[unmatched])).astype(float)
    heights_out[suppressed] = 0
    return positions, heights_out


def peak_dict(d, p, h, frame):
    """Store one frame's peaks in the tracking dictionary ``d`` and return it.

    ``d['frame_NNNNN']['peak_MMMMM']`` is set to ``{'position': ..., 'height': ...}``
    for every peak with non-zero height; zero height is how correspond() marks
    a lost peak. ``d`` is modified in place.

    A single peak given as a 1-D position (with a scalar height) is treated as
    a one-row table. The old code stored it reversed with ``np.flip``, which
    swapped the first and last coordinates.
    """
    frame_entry = {}
    d['frame_' + str(frame).zfill(5)] = frame_entry
    positions = np.atleast_2d(p)
    heights = np.atleast_1d(h)
    for i, height in enumerate(heights):
        if height != 0:
            frame_entry['peak_' + str(i).zfill(5)] = {'position': positions[i, :], 'height': height}
    return d
    

def points2ppm(data, points, dic, axis):
    """Convert point indices along ``axis`` of ``data`` to ppm.

    The conversion is a linear map between the last and the first value of the
    nmrglue unit converter's ppm scale for that axis: index k maps to
    ``last + (first - last) * (k + 1) / N``, where N is the length of the axis.
    ``points`` may be a scalar or an array.
    """
    scale = ng.pipe.make_uc(dic, data, axis).ppm_scale()
    ppm_last, ppm_first = scale[-1], scale[0]
    n_points = np.shape(data)[axis]
    return ppm_last + (ppm_first - ppm_last)*(points+1)*1.0/n_points

# ppm value of every point along axes 0 (N), 1 (CO) and 2 (H) of the first spectrum.
ppm0, ppm1, ppm2 = [points2ppm(data, np.arange(np.shape(data)[ax]), dic, ax) for ax in (0, 1, 2)]

if os.path.isfile('peaks_dictionary.npy'):
    # Reuse cached results. The cache is this script's own pickle-backed output,
    # so only load it from directories you trust.
    print('Loading peak tracking results...')
    d = np.load('peaks_dictionary.npy', allow_pickle=True).item()
    print('loaded!')
    print('Threshold in peak tracking results:  %e' % d['threshold'])
    thresh = d['threshold']

else:
    print('Performing peak tracking...')
    # `start` was taken just before the first spectrum was read; keep it for the
    # total. It was previously overwritten in the loop, so "Total time" only
    # reported the last spectrum.
    total_start = start
    s = (1, 1, 1)
    print('\nSize for "neighbour test" area:  %s' % (s,))
    print('(change the threshold with the -th/--threshold option if necessary)\n')

    print('Spectrum nr  %d' % 0)
    p_old, h = find_all_peaks(data, thresh, s)
    # Normalise to an (n, 3) float table and (n,) heights so [:, axis] indexing always works.
    p_old = np.atleast_2d(p_old).astype('float')
    h = np.atleast_1d(h)
    print('Number of peaks:  %d' % np.shape(p_old)[0])

    # Point indices -> ppm, one whole column per axis.
    for ax in range(3):
        p_old[:, ax] = points2ppm(data, p_old[:, ax], dic, ax)

    d = peak_dict({}, p_old, h, 0)
    print('Time:  %s' % (time.time() - start))

    for i in range(1, max_i):
        iter_start = time.time()
        dic, data = ng.fileio.pipe.read('tempscan'+str(i)+'.ft3')
        data = np.flip(np.flip(np.flip(data, 0), 1), 2)
        print('Spectrum nr  %d' % i)

        p_new, h = find_all_peaks(data, thresh, s)
        p_new = np.atleast_2d(p_new).astype('float')
        h = np.atleast_1d(h)
        print('Number of peaks:  %d' % np.shape(p_new)[0])

        for ax in range(3):
            p_new[:, ax] = points2ppm(data, p_new[:, ax], dic, ax)

        p_new, h = correspond(p_old, p_new, h)  # match peaks of frame i to those of frame i-1
        d = peak_dict(d, p_new, h, i)  # dictionary with all peaks of all frames
        p_old = p_new

        print('Time:  %s' % (time.time() - iter_start))

    print('Total time:  %s' % (time.time() - total_start))

    d['threshold'] = thresh

    print('...writing peaks to file...')
    start_writing = time.time()
    np.save('peaks_dictionary.npy', d)
    with open('peaks_dictionary.txt', 'wt') as out:  # human-readable peak tracking results
        pprint(dict(d.items()), stream=out)
    end_writing = time.time()
    print('peak tracking done!')


###Correspondence between peak tracking results and sparky peak list ###

# Sparky peak list: one header line, then "<assignment> <w1> <w2> <w3>" per peak.
# ndmin=2 / atleast_1d keep a list with a single peak 2-D / 1-D, as the code below expects.
test_list = np.loadtxt(pl_name, skiprows=1, usecols=(1, 2, 3), ndmin=2)
print('\n%d peaks in sparky peak list' % np.shape(test_list)[0])
names = np.atleast_1d(np.loadtxt(pl_name, dtype='|S15', skiprows=1, usecols=(0,)))


def _leading_number(peak_label):
    """Return the first run of digits in a peak label as an int (e.g. 'A23' -> 23)."""
    digits_only = ''.join((ch if ch in '0123456789' else ' ') for ch in peak_label)
    return int(float(digits_only.split()[0]))


names1 = np.array([_leading_number(n) for n in names], dtype=int)

# Peaks tracked in the first spectrum: their numbers and (N, CO, H) ppm positions.
first_frame = d['frame_'+str(0).zfill(5)]
first_frame_keys = list(first_frame.keys())
peak_number = np.array([int(k.replace('peak_', '')) for k in first_frame_keys], dtype=int)  # e.g. 'peak_00023' -> 23
first_frame_positions = np.array([first_frame[k]['position'] for k in first_frame_keys], dtype=float)
pn = int(peak_number.max()) + 1  # one past the highest peak number in the first spectrum

print('%d peaks in first spectrum (peak tracker results) \n' % len(peak_number))

dist = distance_matrix(test_list, first_frame_positions)
closest = np.argmin(dist, axis=1)  # closest tracked peak to each sparky peak

u_closest, u_count = np.unique(closest, return_counts=True)
if np.shape(u_closest) != np.shape(closest):  # several sparky peaks matched the same tracked peak
    print('Problem with peak list correspondence!!!')

for shared in u_closest[u_count > 1]:
    print('Problem peak names:  %s' % (names[closest == shared],))

selected = peak_number[closest]  # tracked-peak number for every sparky peak

# positions[peak, spectrum, axis] in ppm; NaN where a peak was not tracked.
positions = np.full([pn, max_i, 3], np.nan)
for fj, frame in d.items():
    if fj == 'threshold':
        continue
    j = int(fj.replace('frame_', ''))
    for pk, entry in frame.items():
        i = int(pk.replace('peak_', ''))
        if i < pn:
            positions[i, j, :] = np.array(entry['position'])

### Kinetics ###

# corr[peak, column, axis] with axis 0/1/2 written to corr_N / corr_CO / corr_H.
# Columns: 0 peak number, 1-3 quadratic a, b, c, 4 quadratic residual,
# 5-6 linear k, b, 7 linear residual, 8 number of points, 9 F-test confidence.
corr = np.zeros((len(names), 10, 3))

if os.path.isfile('corr_H.txt'):
    print('\nLoading kinetics results...')
    corr[:, :, 0] = np.loadtxt('corr_N.txt', delimiter='\t', skiprows=1)
    corr[:, :, 1] = np.loadtxt('corr_CO.txt', delimiter='\t', skiprows=1)
    corr[:, :, 2] = np.loadtxt('corr_H.txt', delimiter='\t', skiprows=1)
    print('loaded!')
else:
    def fit_positions(p, names, i, axis):
        """Fit linear and quadratic temperature dependences of one peak coordinate.

        ``p`` holds the coordinate in every spectrum (NaN where absent);
        ``names`` and ``i`` are accepted for call compatibility but unused.
        Residuals are mean squared errors over the points used. ``q`` is the
        F-test confidence that the quadratic model improves on the linear one;
        it is NaN when there are too few points (n <= 3) for the test.

        Returns (a, b, c, res_sq, k, b1, res_lin, n_points, q).
        """
        print('Axis %s' % axis)
        non_nan = np.where(~np.isnan(p))[0]
        n = len(non_nan)
        if n == 0:
            nan = float('nan')
            return nan, nan, nan, nan, nan, nan, nan, 0, nan
        t = temp_range[non_nan]
        y = p[non_nan]
        if n != 1:
            k, b1 = np.polyfit(t, y, 1)
            a, b, c = np.polyfit(t, y, 2)
        else:
            k, b1 = 0, y[0]
            a, b, c = 0, 0, y[0]

        res_lin = np.sum(np.square(k*t + b1 - y))/n
        print('Linear fit coefficients (k, b in k*x+b):  %s %s' % (k, b1))
        print('Linear fit residual:  %s' % res_lin)

        res_sq = np.sum(np.square(a*np.square(t) + b*t + c - y))/n
        print('Square fit coefficients (a, b, c in a*x^2+bx+c):  %s %s %s' % (a, b, c))
        print('Square fit residual %s' % res_sq)

        if n > 3:
            # Nested-model F-test (linear: 2 params, quadratic: 3 params):
            # F = (RSS_lin - RSS_sq) / (RSS_sq / (n - 3)) ~ F(dfn=1, dfd=n-3).
            # The degrees of freedom were previously swapped.
            F = (res_lin - res_sq)*(n - 3)/res_sq
            q = scipy.stats.f.cdf(F, dfn=1, dfd=n - 3)
        else:
            q = np.nan

        return a, b, c, res_sq, k, b1, res_lin, n, q

    for i in range(len(names)):
        print('\nPeak  %s (nr %d )' % (names[i], i))
        corr[i, 0, :] = names1[i]
        for axis in [0, 1, 2]:
            p = np.array(positions[selected[i], :, axis])*1e3  # parts per billion!
            print('Number of spectra where peak is present:  %d' % np.count_nonzero(~np.isnan(p)))
            corr[i, 1:, axis] = fit_positions(p, names, i, axis)

    header = 'Peak nr \t a \t b \t c \t Square fit res \t k \t b \t Linear fit res \t peak present in ... spectra \t Likelihood of quadratic model'
    fmt = ['%03d', '%e', '%e', '%e', '%e', '%e', '%e', '%e', '%d', '% 8.5f']
    for axis, suffix in enumerate(('N', 'CO', 'H')):
        np.savetxt('corr_' + suffix + '.txt', corr[:, :, axis], fmt, delimiter='\t', header=header)


### Graphic representation - kinetics ###

def kineticsplot(axis, axes, peakname, positions_selectedpeak):
    """Plot one coordinate of the selected peak against temperature, with its fits.

    axis : 0, 1 or 2 (plotted as N / CO / H, in red / green / blue)
    axes : matplotlib Axes to draw on
    peakname : sparky label; its first number selects the row of ``corr``
    positions_selectedpeak : (n_spectra, 3) ppm positions, NaN where the peak is absent

    The measured positions are plotted in ppb, together with the linear fit
    (black) and quadratic fit (cyan) stored in ``corr``.
    """
    digits_only = ''.join((ch if ch in '0123456789' else ' ') for ch in peakname)
    n = int([float(x) for x in digits_only.split()][0])
    pp = np.where(corr[:, 0, 0] == n)[0][0]

    # axis -> (line colour, dimension name); other values raise KeyError.
    color1, dimension = {0: ('r', 'N'), 1: ('g', 'CO'), 2: ('b', 'H')}[axis]
    coefficients = corr[pp, 1:, axis]

    a, b, c = coefficients[:3]  # a, b, c, res_sq, k, b1, res_lin -> 0, 1, 2, 3, 4, 5, 6
    k, b1 = coefficients[4], coefficients[5]

    # Spectra where this coordinate is known. Testing the whole (n, 3) array
    # returned every row index three times.
    present = np.where(~np.isnan(positions_selectedpeak[:, axis]))[0]
    t = temp_range[present]

    axes.plot(temp_range, positions_selectedpeak[:, axis]*1e3, color=color1)
    axes.plot(t, k*t + b1, 'black')
    axes.plot(t, a*t**2 + b*t + c, 'c')
    print(dimension + ' dimension:')
    print('Square fit residual:  %s ("per point")' % coefficients[3])
    print('Linear fit residual:  %s ("per point")' % coefficients[6])

class KinNFrame(wx.Frame):
    """Window with the temperature dependence of the selected peak's N coordinate (axis 0)."""

    def __init__(self, parent, peakname, positions_selectedpeak):
        # Body re-indented with spaces: it was tab-indented under a space-indented
        # def, which Python 3 (and python2 -tt) rejects as inconsistent indentation.
        digits_only = ''.join((ch if ch in '0123456789' else ' ') for ch in peakname)
        pknr = int([float(x) for x in digits_only.split()][0])
        # `parent` is accepted but not passed on: the frame is created with parent=None.
        wx.Frame.__init__(self, None, size=(600,500), title='Peak '+str(pknr)+', N: '+os.path.split(os.getcwd())[1])
        self.axis = 0

        self.figure = Figure()
        self.axesobject = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.axesobject.set_xlabel('degrees, K')
        self.axesobject.set_ylabel('ppb')

        kineticsplot(self.axis, self.axesobject, peakname, positions_selectedpeak)
        self.canvas.draw()
        
class KinCOFrame(wx.Frame):
    """Window with the temperature dependence of the selected peak's CO coordinate (axis 1)."""

    def __init__(self, parent, peakname, positions_selectedpeak):
        # Body re-indented with spaces: it was tab-indented under a space-indented
        # def, which Python 3 (and python2 -tt) rejects as inconsistent indentation.
        digits_only = ''.join((ch if ch in '0123456789' else ' ') for ch in peakname)
        pknr = int([float(x) for x in digits_only.split()][0])
        # `parent` is accepted but not passed on: the frame is created with parent=None.
        wx.Frame.__init__(self, None, size=(600,500), title='Peak '+str(pknr)+', CO: '+os.path.split(os.getcwd())[1])
        self.axis = 1

        self.figure = Figure()
        self.axesobject = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.axesobject.set_xlabel('degrees, K')
        self.axesobject.set_ylabel('ppb')

        kineticsplot(self.axis, self.axesobject, peakname, positions_selectedpeak)
        self.canvas.draw()

class KinHFrame(wx.Frame):
    """Window with the temperature dependence of the selected peak's H coordinate (axis 2)."""

    def __init__(self, parent, peakname, positions_selectedpeak):
        # Body re-indented with spaces: it was tab-indented under a space-indented
        # def, which Python 3 (and python2 -tt) rejects as inconsistent indentation.
        digits_only = ''.join((ch if ch in '0123456789' else ' ') for ch in peakname)
        pknr = int([float(x) for x in digits_only.split()][0])
        # `parent` is accepted but not passed on: the frame is created with parent=None.
        wx.Frame.__init__(self, None, size=(600,500), title='Peak '+str(pknr)+', H: '+os.path.split(os.getcwd())[1])
        self.axis = 2

        self.figure = Figure()
        self.axesobject = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.axesobject.set_xlabel('degrees, K')
        self.axesobject.set_ylabel('ppb')

        kineticsplot(self.axis, self.axesobject, peakname, positions_selectedpeak)
        self.canvas.draw()
        print('...select next peak if necessary...')


### Graphic representation - peak tracker ###

class SpectrumPanel(wx.Panel):
    """Interactive H x CO contour view of one N plane of the current spectrum.

    The top slider and the "-1"/"+1" buttons step through the spectra. The
    bottom slider picks the N plane. The selected peak's tracked position is
    marked with a "+" (pale when another plane is shown). Creating the panel
    also opens the three kinetics windows (N, CO, H).
    """

    def __init__(self, parent, positions_selectedpeak, peakname):
        wx.Panel.__init__(self, parent, -1)
        self.zoomclicked = False
        # (n_spectra, 3) ppm positions of the selected peak; NaN where it was not tracked.
        self.positions_selectedpeak = positions_selectedpeak

        self.val_spec = 0
        self._read_spectrum(self.val_spec)

        self.udic = ng.pipe.guess_udic(self.dic, self.data)
        self.ucN = ng.pipe.make_uc(self.dic, self.data, dim=0)
        self.ucCO = ng.pipe.make_uc(self.dic, self.data, dim=1)
        self.ucH = ng.pipe.make_uc(self.dic, self.data, dim=2)

        self.xlabel = 'ppm, ' + (self.udic[2]['label'].split("'"))[0]  # H
        self.ylabel = 'ppm, ' + (self.udic[1]['label'].split("'"))[0]  # C=O
        self.zlabel = 'ppm, ' + (self.udic[0]['label'].split("'"))[0]  # N (slider)

        self.xaxis = ppm2
        self.yaxis = ppm1
        self.zaxis = ppm0
        self.xlim = [self.xaxis[-1], self.xaxis[0]]
        self.ylim = [self.yaxis[-1], self.yaxis[0]]

        # Start on the N plane closest to the peak position in the first spectrum.
        # Ndim_pts is kept as a plain int index because wx.Slider expects an int
        # value (it used to be the index array returned by np.where).
        self.Ndim_pts = self._nearest_plane(positions_selectedpeak[0, 0])
        self.Ndim_peak = self.zaxis[self.Ndim_pts]
        self.Ndim = self.Ndim_peak

        self.sld1 = wx.Slider(self, value=0, minValue=first_spectrum, maxValue=max_i-1)
        self.sld1.Bind(wx.EVT_SCROLL, self.OnSliderScroll1)
        self.txt1 = wx.StaticText(self, label=self._spectrum_label(0))

        self.sld2 = wx.Slider(self, value=self.Ndim_pts, minValue=0, maxValue=len(ppm0)-1, style=wx.SL_INVERSE)
        self.sld2.Bind(wx.EVT_SCROLL, self.OnSliderScroll2)
        self.txt2 = wx.StaticText(self, label="%.3f" % (self.Ndim)+'   '+self.zlabel)

        self.SaveButton = wx.Button(self, -1, "Save figure")
        self.SaveButton.Bind(wx.EVT_BUTTON, self.SaveFigure)
        self.ButtonMinusSpectrum = wx.Button(self, -1, "-1", size=(35, 35))
        self.ButtonMinusSpectrum.Bind(wx.EVT_BUTTON, self.MinusSpectrum)
        self.ButtonPlusSpectrum = wx.Button(self, -1, "+1", size=(35, 35))
        self.ButtonPlusSpectrum.Bind(wx.EVT_BUTTON, self.PlusSpectrum)

        self.sibut = wx.Button(self, -1, "Zoom")
        self.sibut.Bind(wx.EVT_BUTTON, self.zoom)
        self.hmbut = wx.Button(self, -1, "Reset")
        self.hmbut.Bind(wx.EVT_BUTTON, self.reset)
        self.hibut = wx.Button(self, -1, "Pan")
        self.hibut.Bind(wx.EVT_BUTTON, self.pan)

        self.figure = Figure()
        self.axesobject = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)
        # Hidden toolbar: only its zoom/pan/home actions are used, via the buttons above.
        self.toolbar = NavigationToolbar(self.canvas)
        self.toolbar.Hide()

        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.sld1, proportion=0, flag=wx.LEFT | wx.TOP | wx.GROW)
        self.sizer.Add(self.txt1, proportion=0, flag=wx.RIGHT | wx.TOP | wx.GROW)
        self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.sizer.Add(self.sld2, proportion=0, flag=wx.LEFT | wx.BOTTOM | wx.GROW)
        self.sizer.Add(self.txt2, proportion=0, flag=wx.RIGHT | wx.BOTTOM | wx.GROW)

        self.sizerZoom = wx.BoxSizer(wx.HORIZONTAL)
        self.sizerZoom.Add(self.sibut, flag=wx.LEFT)
        self.sizerZoom.Add(self.hmbut, flag=wx.ALIGN_CENTER_HORIZONTAL)
        self.sizerZoom.Add(self.hibut, flag=wx.RIGHT)

        self.sizer.Add(self.sizerZoom, flag=wx.ALIGN_CENTER_HORIZONTAL)
        self.sizer.Add(self.SaveButton, flag=wx.ALIGN_CENTER_HORIZONTAL)

        self.sizerH = wx.BoxSizer(wx.HORIZONTAL)
        self.sizerH.Add(self.ButtonMinusSpectrum, proportion=0, flag=wx.LEFT | wx.TOP)
        self.sizerH.Add(self.sizer, proportion=0)
        self.sizerH.Add(self.ButtonPlusSpectrum, proportion=0, flag=wx.RIGHT | wx.TOP)
        self.SetSizer(self.sizerH)

        self.axesobject.set_xlabel(self.xlabel)
        self.axesobject.set_ylabel(self.ylabel)
        self.axesobject.yaxis.set_label_position("right")
        self.axesobject.yaxis.tick_right()

        self.factor = 1.3  # ratio between successive contour levels
        self.levels = 20   # number of contour levels
        if 'thresh' in globals():
            self.thrmin = thresh
        else:
            plus, minus = self.partition(self.data, self.Ndim)
            self.thrmin = max(np.amax(plus), np.amax(minus)) * 0.2  # threshold factor
        print('Contour plot threshold: %e \n' % self.thrmin)

        self.contplot(self.data, self.Ndim, self.thrmin, self.xaxis, self.yaxis, self.factor, self.levels, positions_selectedpeak)
        self.plotpeakpos(self.val_spec, positions_selectedpeak)

        self.kinN = KinNFrame(self, peakname, positions_selectedpeak)
        self.kinN.Show()
        self.kinCO = KinCOFrame(self, peakname, positions_selectedpeak)
        self.kinCO.Show()
        self.kinH = KinHFrame(self, peakname, positions_selectedpeak)
        self.kinH.Show()

        self.figurename = last_figurename

    # --- helpers -----------------------------------------------------------

    def _nearest_plane(self, ndim_ppm):
        """Return the index (int) of the N plane in self.zaxis closest to ndim_ppm."""
        return int(np.abs(self.zaxis - ndim_ppm).argmin())

    def _read_spectrum(self, index):
        """Load tempscan<index>.ft3 into self.dic / self.data, flipped like the tracker input."""
        self.dic, self.data = ng.fileio.pipe.read('tempscan'+str(int(index))+'.ft3')
        self.data = np.flip(np.flip(np.flip(self.data, 0), 1), 2)

    def _spectrum_label(self, index):
        """Text shown under the spectrum slider."""
        return 'Spectrum nr ' + str(index) + ' , T=' + "%.2f" % (temp_range[index]) + ' K'

    def _remember_view(self):
        """Store the current axes limits so contplot can restore a zoomed view."""
        self.xlim = self.axesobject.get_xlim()
        self.ylim = self.axesobject.get_ylim()

    def _redraw(self):
        """Redraw the contour plot of the current spectrum and N plane."""
        self.contplot(self.data, self.Ndim, self.thrmin, self.xaxis, self.yaxis, self.factor, self.levels, self.positions_selectedpeak)

    def _show_spectrum(self, index):
        """Display spectrum `index`, jumping to the selected peak's N plane when it is tracked there."""
        self.val_spec = int(index)
        self.txt1.SetLabel(self._spectrum_label(self.val_spec))
        self._read_spectrum(self.val_spec)
        ndim = self.positions_selectedpeak[self.val_spec, 0]
        if np.isnan(ndim):
            # Peak not tracked in this spectrum: redraw the current plane without a marker.
            self._redraw()
            return
        plane = self._nearest_plane(ndim)
        self.Ndim_peak = self.zaxis[plane]
        if self.Ndim != self.Ndim_peak:
            self.Ndim = self.Ndim_peak
            self.Ndim_pts = plane
            self.sld2.SetValue(self.Ndim_pts)
            self.txt2.SetLabel("%.3f" % (self.Ndim)+'   '+self.zlabel)
        self._redraw()
        self.plotpeakpos(self.val_spec, self.positions_selectedpeak)

    # --- plotting ----------------------------------------------------------

    def partition(self, data, Ndim):
        """Split the N plane nearest to Ndim into positive and |negative| parts (2-D arrays)."""
        plane = np.squeeze(data[self._nearest_plane(Ndim), :, :])  # N dimension is the first one
        data_plus = np.maximum(plane, 0)
        data_minus = np.abs(np.minimum(plane, 0))
        return data_plus, data_minus

    def contplot(self, data, Ndim, thrmin, xaxis, yaxis, factor, levels, positions_selectedpeak):
        """Redraw positive (winter) and negative (autumn) contours of the current plane.

        Uses the panel's own state (self.data, self.Ndim, self.thrmin, self.xaxis,
        self.yaxis, self.factor, self.levels); the arguments are kept for call
        compatibility. Tracked peaks of the sparky list lying on this plane are
        labelled with their numbers.
        """
        self.axesobject.clear()
        self.axesobject.set_xlabel(self.xlabel)
        self.axesobject.set_ylabel(self.ylabel)
        self.axesobject.yaxis.set_label_position("right")
        self.axesobject.yaxis.tick_right()

        plus, minus = self.partition(self.data, self.Ndim)  # computed once per redraw
        # self.levels log-spaced levels from thrmin to thrmin * factor**levels.
        contour_levels = np.logspace(math.log(self.thrmin, self.factor),
                                     math.log(self.thrmin * self.factor ** self.levels, self.factor),
                                     self.levels, endpoint=True, base=self.factor)
        if np.max(plus) >= self.thrmin:
            self.axesobject.contour(self.xaxis, self.yaxis, plus, contour_levels, cmap='winter')
        if np.max(minus) >= self.thrmin:
            self.axesobject.contour(self.xaxis, self.yaxis, minus, contour_levels, cmap='autumn')

        if not self.zoomclicked:
            self.axesobject.set_xlim(self.xaxis[-1], self.xaxis[0])
            self.axesobject.set_ylim(self.yaxis[-1], self.yaxis[0])
        else:
            self.axesobject.set_xlim(self.xlim)
            self.axesobject.set_ylim(self.ylim)

        for i in range(pn):
            if positions[i, self.val_spec, 0] == self.Ndim and i in selected:
                label_index = np.where(selected == i)[0][0]
                self.axesobject.text(positions[i, self.val_spec, 2], positions[i, self.val_spec, 1], names1[label_index], color="green", fontsize=12)
        self.canvas.draw()

    def plotpeakpos(self, val_spec, positions_selectedpeak):
        """Mark the selected peak in spectrum val_spec with a bold black '+'."""
        self.axesobject.plot(positions_selectedpeak[val_spec, 2], positions_selectedpeak[val_spec, 1], marker='+', mew=2, ms=20, color='black')
        self.canvas.draw()

    def plotpeakpos_pale(self, val_spec, positions_selectedpeak):
        """Mark the selected peak with a thin pale '+' (shown when another N plane is displayed)."""
        self.axesobject.plot(positions_selectedpeak[val_spec, 2], positions_selectedpeak[val_spec, 1], marker='+', mew=0.5, ms=15, color='cornflowerblue')
        self.canvas.draw()

    # --- event handlers ----------------------------------------------------

    def OnSliderScroll1(self, e1):  # spectrum number
        self._remember_view()
        self._show_spectrum(e1.GetEventObject().GetValue())

    def OnSliderScroll2(self, e2):  # N (0) dim
        self._remember_view()
        self.Ndim_pts = e2.GetEventObject().GetValue()
        self.Ndim = ppm0[self.Ndim_pts]
        self.txt2.SetLabel("%.3f" % (self.Ndim)+'   '+self.zlabel)
        self._redraw()
        if self.Ndim == self.Ndim_peak:
            self.plotpeakpos(self.val_spec, self.positions_selectedpeak)
        else:
            self.plotpeakpos_pale(self.val_spec, self.positions_selectedpeak)

    def SaveFigure(self, e):
        self.figure.savefig('figure'+str(self.figurename)+'.eps', format='eps')
        print('Figure saved as  "figure'+str(self.figurename)+'.eps"')
        self.figurename = self.figurename+1

    def MinusSpectrum(self, e):
        self._remember_view()
        if self.val_spec > first_spectrum:
            self._show_spectrum(self.val_spec - 1)
            self.sld1.SetValue(self.val_spec)

    def PlusSpectrum(self, e):
        self._remember_view()
        if self.val_spec < max_i - 1:
            self._show_spectrum(self.val_spec + 1)
            self.sld1.SetValue(self.val_spec)

    def zoom(self, event):
        self.toolbar.zoom()
        self.zoomclicked = True

    def reset(self, event):
        self.toolbar.home()
        self.zoomclicked = False

    def pan(self, event):
        self.toolbar.pan()

class TestFrame(wx.Frame):
    """Top-level window hosting the SpectrumPanel for one selected peak."""

    def __init__(self, parent, title, positions_selectedpeak, peakname):
        wx.Frame.__init__(self, parent, title=title, size=(750, 700))

        # SpectrumPanel.__init__ already draws the contour plot and the peak
        # marker. The second contplot/plotpeakpos pair here only cleared and
        # redrew the same picture.
        self.p1 = SpectrumPanel(self, positions_selectedpeak, peakname)

### Select peak ###

def select_peak(peakname):
    """Return the tracked ppm position of a sparky peak in every spectrum.

    The sparky name is mapped to its tracked-peak number via ``selected``; the
    first match is used if the name occurs more than once. The result is a
    (max_i, 3) array of positions taken from ``d``. Rows are NaN for spectra
    where the peak was not tracked, and also for frames missing from ``d``
    (those used to be left at 0 ppm).
    """
    peakindex = np.where(names == peakname)[0]
    selectedpeakindict = selected[peakindex][0]
    peak_key = 'peak_' + str(selectedpeakindict).zfill(5)

    positions_selectedpeak = np.full((max_i, 3), np.nan)
    for fj, frame in d.items():
        if fj == 'threshold':
            continue
        j = int(fj.replace('frame_', ''))
        if peak_key in frame:
            positions_selectedpeak[j, :] = np.array(frame[peak_key]['position'])
    return positions_selectedpeak


class ListBoxFrame(wx.Frame):
    """Peak selector: double-clicking a sparky peak name opens its spectrum window."""

    def __init__(self):
        wx.Frame.__init__(self, None, -1, 'Double-click to select peak', size=(350, 250))
        panel = wx.Panel(self, -1)
        self.frame = None  # the TestFrame currently open, if any

        self.listBox = wx.ListBox(panel, -1, (20, 0), (300, 180), names, wx.LB_SINGLE)  # single selection
        self.Bind(wx.EVT_LISTBOX_DCLICK, self.OnSelect, id=-1)
        print('Select peak!')

    def OnSelect(self, event):
        self.peakname = self.listBox.GetString(self.listBox.GetSelection())
        print('\nPeak %s selected' % self.peakname)
        self.positions_selectedpeak = select_peak(self.peakname)

        # Destroy (rather than only hide) the previous spectrum window so that
        # repeated selections do not pile up hidden frames. A wx window whose
        # C++ object was already destroyed (closed by the user) evaluates False.
        if self.frame:
            self.frame.Destroy()

        self.frame = TestFrame(None, os.path.split(os.getcwd())[1], self.positions_selectedpeak, self.peakname)
        self.frame.Show(True)

                
# Start the GUI: the peak selector window drives everything else.
app = wx.App()
selector = ListBoxFrame()
selector.Show()
app.MainLoop()
