#!/usr/bin/python

import itertools
import math
import os
import struct
import time

import matplotlib.pyplot as plt
import nmrglue as ng
import numpy as np
import scipy
import scipy.special

###### General parameters and Mnova data import
# Fraction of the FID chunks that is kept in every candidate sampling.
factor = 0.2 # proportion of chunks measured
# Largest number of schedules that is still enumerated exhaustively; above
# this count the script switches to randomly drawn schedules.
samp_limit = int(1e4)

# multiplets1.txt: column 0 is used as the chemical shift in ppm, column 1 as
# the amplitude of each simulated line.  ndmin=2 keeps the array
# two-dimensional when the file contains a single row, so the column slicing
# below also works for a one-peak spectrum (plain loadtxt would return a 1-D
# array and the slicing would raise IndexError).
multiplets = np.loadtxt('multiplets1.txt', ndmin=2)
shift_ppm = multiplets[:, 0]
a = multiplets[:, 1]

# peaks1.txt: measured peak widths (presumably FWHM in Hz - confirm against
# the Mnova export).  Despite its name, fwhm_Hz holds the mean width
# multiplied by pi, which is later used as the decay constant of the
# simulated FID.
peakwidths = np.loadtxt('peaks1.txt')
fwhm_Hz = np.mean(peakwidths)*math.pi

###### PSYCHE parameters
# Pick the first entry of the working directory whose name contains
# "PSYCHE"; it must hold the Varian 'fid' and 'procpar' files read below.
dirs = os.listdir(os.getcwd())
psyche_dir = None
for entry in dirs:
    if "PSYCHE" in entry:
        psyche_dir = os.path.join(os.getcwd(), entry)
        break
if psyche_dir is None:
    # Without this check the script would stop later with a NameError on
    # psyche_dir, which hides the real cause.
    raise IOError('No entry containing "PSYCHE" found in ' + os.getcwd())

dic, data = ng.varian.read(dir=psyche_dir, fid_file='fid', procpar_file='procpar')
dic1 = ng.varian.read_procpar(os.path.join(psyche_dir, 'procpar'))
# npo is half of the procpar 'np' value (presumably the number of complex
# points - confirm); floor division keeps it an integer.
npo = int(dic1['np']['values'][0])//2
at = float(dic1['at']['values'][0])
reffrq = float(dic1['reffrq']['values'][0])
sfrq = float(dic1['sfrq']['values'][0])
sw = float(dic1['sw']['values'][0])
sw1 = float(dic1['sw1']['values'][0])
rp = float(dic1['rp']['values'][0])
lp = float(dic1['lp']['values'][0])

# Apply the stored phase corrections (rp and lp are converted from degrees)
# to the measured data and return to the time domain, so that `data` is a
# phased FID.
D = len(data)
data = data*np.exp(-1j*rp*math.pi/180) # p0 phasing
# Go to the (shifted, reversed) frequency domain for the linear phase ramp.
data = np.flipud(np.fft.fftshift(np.fft.fft(data)))
data = data*np.exp(-1j*np.arange(D)*lp*math.pi/(D*180)) # p1 phasing
# Undo the reversal and shift, then transform back to the time domain.
data = np.fft.ifft(np.flipud(np.fft.fftshift(data)))

######
chunk = int(math.floor(sw/sw1 + 0.5))
print '\nChunk = ', chunk, 'points'
t = np.linspace(0, at, npo)
shift_Hz = (sfrq -(shift_ppm*1e-6*reffrq + reffrq))*1e6
shift = np.round(shift_Hz*npo/sw + npo/2)
fwhm_t = fwhm_Hz*npo/sw
fwhm = fwhm_t*at/npo

###### Simulating singlet spectrum
# Point index 0 .. npo-1, shared by all simulated lines.
n = np.linspace(0, npo-1, npo)

# Sum one exponentially damped complex oscillation per line: amplitude from
# `a`, frequency from the grid position in `shift`, common decay `fwhm`.
fid = np.zeros(npo, dtype=complex)
for amplitude, position in zip(a, shift):
    fid += amplitude*np.exp((2*math.pi*1j*position/npo - fwhm)*n)

# Fully sampled reference spectrum, scaled so that its tallest point has unit
# magnitude.  The index of that point is reused to normalise every
# under-sampled spectrum in the same way.
s = np.fft.fft(fid)
peak_for_normalization = np.abs(s).argmax()
s /= np.abs(s[peak_for_normalization])

###### Generating samplings
full_number_of_chunks = int(math.floor(npo/chunk)) 
number_of_measured_chunks = int(math.floor(full_number_of_chunks*factor))
print 'Full number of chunks -', full_number_of_chunks, ', number of measured chunks -', number_of_measured_chunks
number_of_samplings = int(math.floor(scipy.misc.comb(full_number_of_chunks-1, number_of_measured_chunks-1)))
print 'Number of all samplings: ', number_of_samplings

print '\nGenerating sampling schedules...'
start = time.time()
np.random.seed(1)
if number_of_samplings > samp_limit:
	print 'Using randomly selected samplings!'
	Chunks_measured = np.ones((samp_limit+2, number_of_measured_chunks))
	for i in range(samp_limit):
		Chunks_measured[i, 1:] = np.sort(np.random.permutation(np.arange(2, full_number_of_chunks)) \
			[:number_of_measured_chunks - 1])
	Chunks_measured[samp_limit, :] = np.linspace(1, number_of_measured_chunks, number_of_measured_chunks) # truncation
	Chunks_measured[samp_limit+1, :] = Chunks_measured[samp_limit, :]*math.floor(1/factor) - \
		(math.floor(1/factor) - 1) # uniform distribution
else:
	print 'Using all samplings!'
	C = np.array(list(itertools.combinations(np.linspace(2, full_number_of_chunks, full_number_of_chunks-1), \
		number_of_measured_chunks-1)))
	Chunks_measured = np.hstack((np.ones((np.shape(C)[0], 1)), C))
Chunks_measured = Chunks_measured.astype(int)
print '\nSamplings: ', Chunks_measured, np.shape(Chunks_measured)

###### Comparing samplings
s_nus = np.zeros((np.shape(Chunks_measured)[0], len(s)), dtype=complex)
max_artefact = np.zeros((np.shape(Chunks_measured)[0],))
L1 = np.zeros((np.shape(Chunks_measured)[0],))
L2 = np.zeros((np.shape(Chunks_measured)[0],))

if os.path.isfile('s_nus.npy'):
	s_nus = np.load('s_nus.npy')
	for j in range(np.shape(Chunks_measured)[0]):
	    max_artefact[j] = np.max(np.abs(s_nus[j, :]-s))
	    L1[j] = np.sum(np.abs(s_nus[j, :]-s))

else:
	for j in range(np.shape(Chunks_measured)[0]):
	    a = Chunks_measured[j, :] 
	    fid_nus = np.zeros(np.shape(fid), dtype=complex)
	    for k in range(np.shape(Chunks_measured)[1]):
	        fid_nus[(a[k]-1)*chunk:a[k]*chunk] = fid[(a[k]-1)*chunk:a[k]*chunk]
	    s_nus[j, :] = np.fft.fft(fid_nus)
	    s_nus[j, :] = s_nus[j, :]/np.abs(s_nus[j, peak_for_normalization]) 

	    max_artefact[j] = np.max(np.abs(s_nus[j, :]-s))
	    L1[j] = np.sum(np.abs(s_nus[j, :]-s))

end = time.time()
print 'done!'
print 'Time: ', end - start, 's'

criterion = {}
criterion['Linf'] = max_artefact
criterion['L1'] = L1

main_folder = os.getcwd()

def best_and_worst_samplings(c, bw):
	global s
	os.chdir(main_folder)
	if bw == 'best':
		C = np.argmin(criterion[c])
	if bw == 'worst':
		C = np.argmax(criterion[c])
	selected_spectrum = s_nus[C, :]
	selected_fid = np.fft.ifft(s_nus[C, :])*np.abs(s_nus[C, peak_for_normalization])
	selected_sampling = Chunks_measured[C, :]
	print selected_sampling

	print 'Saving figures of simulated FIDs...'
	plt.figure()
	plt.plot(t, np.real(selected_fid))
	plt.title(bw + ' sampling sampling - ' + c + ' criterion (simulated FID)')
	plt.savefig(bw + '_' + c + '_FID.eps')

	plt.figure()
	plt.plot(np.abs(selected_spectrum))
	plt.plot(np.abs(s))
	plt.title(bw + ' sampling - ' + c + ' criterion \n(simulated spectrum with artefacts)')
	plt.savefig(bw + '_' + c + '_spec0.eps')
	print 'done!'

	print 'Preparing files for reconstruction...'
	folder_r = main_folder + '/' + c 
	if not os.path.exists(folder_r):
	    os.makedirs(folder_r)
	os.chdir(folder_r)

	folder_r_bw = folder_r + '/' + c + '_' + bw
	if not os.path.exists(folder_r_bw):
		os.makedirs(folder_r_bw)
	os.chdir(folder_r_bw)

	np.savetxt(bw + '_sampling.txt', selected_sampling, fmt='%d')

	folder_r_bw0 = folder_r_bw + '/0th_iteration_' + c + '_' + bw
	if not os.path.exists(folder_r_bw0):
		os.makedirs(folder_r_bw0)
	ng.fileio.varian.write(folder_r_bw0, dic, selected_fid, overwrite=True)
	
	I = len(selected_sampling)*chunk
	ind = np.zeros(I,)
	for k in range(len(selected_sampling)):
	    ind[chunk*k:chunk*(k+1)] = np.arange((selected_sampling[k] - 1)*chunk, selected_sampling[k]*chunk, 1)
	ind = ind.astype(int)

	k = 0
	ind_alternated = np.zeros(2*I)
	for i in range(I):
	    ind_alternated[k] = 2*ind[i]
	    ind_alternated[k + 1] = ind_alternated[k] + 1
	    k = k + 2
	ind_alternated = ind_alternated.astype(int)

	regions = 1
	dimensions = 2
	line1 = str(dimensions) + ' 1 ' + str(regions*2*I) + '\n'
	line2 = str(regions) + ' ' + str(2*D)
	s1 = ''
	for i in range(2*I):
	    s1 = s1 + '\n' + str(0) + ' ' + str(ind_alternated[i])

	folder = folder_r_bw + '/MDD' 
	if not os.path.exists(folder):
	    os.makedirs(folder)
	os.chdir(folder)

	fid_alternated = np.empty(2*I)
	fid_alternated[0::2] = np.real(data[ind])
	fid_alternated[1::2] = np.imag(data[ind])
	f = ''
	for j in range(2*I):
	    f = f + '\n' + str(fid_alternated[j])
	contents = 'mdd asc sparse f180.0 \n ./MDD/region01.mdd \n MDD sparse\n $ \n' + line1 + line2 + s1 + f # contains the data to be written to .mdd file

	name = str(1)
	file = open(name + '.mdd', 'w')
	file.write(contents)
	file.close()

	command = 'cssolver ' + name + ' CS_alg=IST CS_niter=300 CS_VE=n MDD_NOISE=1 > ./' + name + '.log'
	command = "tcsh -c  '" + command + "'"
	os.system(command) # executes .mdd file
	print 'done!'

	print 'Reading reconstruction results...'
	file = open(name + '.cs', 'rb')
	cs = np.empty(2*D)
	for j in range(2*D):
	    cs[j] = float(struct.unpack('f', file.read(4))[0])
	file.close()
	fid_rec = cs[0::2].astype(np.float32) + 1j*cs[1::2].astype(np.float32)
	print 'done!'
	print 'Writing reconstructed FIDs...'
	ng.fileio.varian.write(folder_r_bw, dic, fid_rec, overwrite=True)
	print 'done!'

	return None

bw = ['best', 'worst']
c = ['Linf', 'L1']
for i in c:
	for j in bw:
		print '\n', j, 'sampling for', i, 'criterion'
		best_and_worst_samplings(i, j)

plt.show()