import json
import os
import math
import matplotlib.pyplot as plt
import nmrglue as ng
import numpy as np
import scipy
import sys
import tensorflow as tf

from PyQt5.QtChart import QAreaSeries, QChart, QChartView, QLineSeries
from PyQt5.QtCore import QAbstractTableModel, QDir, QPoint, QPointF, QRectF, QSize, Qt
from PyQt5.QtGui import QImage, QPainter, QPixmap
from PyQt5.QtWidgets import QAbstractItemView, QApplication, QFileDialog, QGridLayout, QHBoxLayout, QHeaderView, QLabel, QLineEdit, QSizePolicy, QTableView, QVBoxLayout, QWidget

#from PyQt6.QtCharts import QAreaSeries, QChart, QChartView, QLineSeries
#from PyQt6.QtCore import QAbstractTableModel, QDir, QPoint, QPointF, QRectF, QSize, Qt
#from PyQt6.QtGui import QImage, QPainter, QPixmap
#from PyQt6.QtWidgets import QAbstractItemView, QApplication, QFileDialog, QGridLayout, QHBoxLayout, QHeaderView, QLabel, QLineEdit, QSizePolicy, QTableView, QVBoxLayout, QWidget

#Numbers of pixels in speed, time/frequency, and series
#(tim_pixels is also the number of frequency bins: the FFT keeps the length)
spd_pixels = 256
frq_pixels = 256
tim_pixels = 32768
ser_pixels = 20

#Model-parameter step per pixel along speed (400 over a spd_pixels window)
#and frequency (256 over a frq_pixels window)
spd_param_delta = 400. / spd_pixels
frq_param_delta = 256. / frq_pixels

#Load detector model (relative path: resolved against the working directory)
model = tf.keras.models.load_model(r'model/model.h5')
model.summary()

#Load model metadata
with open(r'model/model.json') as file:
    meta = json.load(file)

#Ranges of width, height, and ratio
twf_param_beg, twf_param_end = meta['twf_beg_param'], meta['twf_end_param']
thf_param_beg, thf_param_end = meta['thf_beg_param'], meta['thf_end_param']
trf_param_beg, trf_param_end = meta['trf_beg_param'], meta['trf_end_param']

#Range of noise (a falsy value makes Picker.draw show only a zero line)
nsf_param_end = meta['nsf_end_param']

#Read ser_pixels Bruker FIDs from sub-directories "1", "2", ... of the dataset
#directory given as the first command-line argument, one FID per column.
#The column assignment requires every FID to have exactly tim_pixels points.
#`dic` keeps the parameter dictionary of the last FID read.
fid = np.empty((tim_pixels, ser_pixels), dtype = np.complex128)
for ser_pixel in range(ser_pixels):
    dic, fid[:, ser_pixel] = ng.bruker.read(os.path.join(sys.argv[1], str(1 + ser_pixel)))

#Second argument: name of the series-parameter unit (shown as 'ppb/<name>');
#third argument: presumably the parameter step between consecutive series - confirm
par_unit_name = sys.argv[2]
par_unit_delta = float(sys.argv[3])

#Frequency limits from the processing parameters ABSF1/ABSF2 (labelled ppm in
#the UI) and the resulting frequency step per FFT bin
frq_unit_beg, frq_unit_end = dic['procs']['ABSF1'], dic['procs']['ABSF2']
frq_unit_delta = (frq_unit_end - frq_unit_beg) / tim_pixels
#Speed step per pixel in display units (labelled ppb/<unit>; the 1e3 factor
#presumably converts ppm to ppb)
spd_unit_delta = 1.e3 * frq_unit_delta * spd_param_delta / par_unit_delta / ser_pixels

def frq_unit_from_param(frq_param):
    """Convert a frequency model parameter to display units (ppm).

    Inverse of ``frq_param_from_unit``; parameter 0 maps to ``frq_unit_beg``.
    """
    scaled = frq_param * frq_unit_delta / frq_param_delta
    return frq_unit_beg + scaled

def frq_param_from_unit(frq_unit):
    """Convert a frequency in display units (ppm) to the model parameter.

    Inverse of ``frq_unit_from_param``; ``frq_unit_beg`` maps to parameter 0.
    """
    offset = frq_unit - frq_unit_beg
    return offset * frq_param_delta / frq_unit_delta

def spd_unit_from_param(spd_param):
    """Convert a speed model parameter to display units (ppb per series unit)."""
    return spd_param * spd_unit_delta / spd_param_delta

def spd_param_from_unit(spd_unit):
    """Convert a speed in display units (ppb per series unit) to the model parameter."""
    return spd_unit * spd_param_delta / spd_unit_delta

#The Qt application object; it must exist before any widget below is constructed
app = QApplication([])

class PeakModel(QAbstractTableModel):
    """Read-only table model for the detected peaks.

    ``peakList`` holds one ``(frequency, speed, probability)`` tuple per row;
    every cell is displayed as ``'%.6f'`` text.
    """
    def __init__(self, parent = None):
        QAbstractTableModel.__init__(self, parent)
        self.peakList = list()

    def rowCount(self, parent = None):
        """Return the number of peaks (0 for a valid parent: a table has no child rows)."""
        if parent is not None and parent.isValid():
            return 0
        return len(self.peakList)

    def columnCount(self, parent = None):
        """Return 3 columns - frequency, speed, probability (0 for a valid parent)."""
        if parent is not None and parent.isValid():
            return 0
        return 3

    def data(self, index, role = Qt.ItemDataRole.DisplayRole):
        """Return the formatted cell text for DisplayRole and None for any other role."""
        #An invalid index has row -1, which would silently read the last peak
        if role != Qt.ItemDataRole.DisplayRole or not index.isValid():
            return None
        return '%.6f' % (self.peakList[index.row()][index.column()],)

    def headerData(self, section, orientation, role = Qt.ItemDataRole.DisplayRole):
        """Return column titles for the horizontal header; defer everything else to Qt."""
        if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole:
            if section == 0:
                return 'Frq[ppm]'
            elif section == 1:
                return 'Spd[ppb/' + par_unit_name + ']'
            elif section == 2:
                return 'Probability'
            return None
        return super().headerData(section, orientation, role)

    def setPeakList(self, peakList):
        """Replace all rows with ``peakList`` and notify attached views."""
        self.beginResetModel()
        self.peakList = peakList
        self.endResetModel()

class Widget1D(QWidget):
    """Widget showing a pixmap above a labelled horizontal axis.

    ``setPixmap(pixmap, pnt_hor_min, pnt_hor_max)`` installs the image and the
    axis range. ``pnt_hor_min`` is drawn at the right edge and ``pnt_hor_max``
    at the left edge, with a tick at every multiple of ``hor_step``.
    """
    def __init__(self, parent = None):
        QWidget.__init__(self, parent)
        self.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum)
        self.setPixmap(QPixmap(), 0., 1.)

    def paintEvent(self, event):
        #Margins leave room for tick labels: 50 px left, 15 px right, 20 px below
        rectangle = self.rect().adjusted(50, 0, -15, -20)
        #Reversed axis: pnt_hor_min maps to the right edge, pnt_hor_max to the left
        pix_hor_min, pix_hor_max = rectangle.right(), rectangle.left()
        pix_ver_min = rectangle.bottom()
        pix_hor_ticks = (pix_hor_max - pix_hor_min) * (self.pnt_hor_ticks - self.pnt_hor_min) / (self.pnt_hor_max - self.pnt_hor_min) + pix_hor_min
        pix_hor_ticks = pix_hor_ticks.round().astype(np.int64).tolist()
        painter = QPainter(self)
        try:
            for pix_hor_tick, pnt_hor_tick in zip(pix_hor_ticks, self.pnt_hor_ticks):
                #5 px tick mark below the image with its label centred underneath
                painter.drawLine(pix_hor_tick, pix_ver_min, pix_hor_tick, pix_ver_min + 5)
                text = '%.2f' % (pnt_hor_tick,)
                rect = painter.boundingRect(self.rect(), Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop, text)
                rect.moveCenter(QPoint(pix_hor_tick, pix_ver_min))
                rect.moveTop(pix_ver_min + 5)
                painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, text)
            painter.drawPixmap(rectangle, self.pixmap)
        finally:
            #Release the paint device explicitly instead of relying on garbage collection
            painter.end()

    def sizeHint(self):
        return QSize(321, 292)

    def setPixmap(self, pixmap, pnt_hor_min, pnt_hor_max, hor_step = 0.02):
        """Install ``pixmap`` and the axis limits, rebuild the ticks, and repaint."""
        self.pixmap = pixmap
        self.pnt_hor_min = pnt_hor_min
        self.pnt_hor_max = pnt_hor_max
        #Sorting the limits keeps the ticks when they arrive as min > max
        hor_lo, hor_hi = sorted((pnt_hor_min, pnt_hor_max))
        self.pnt_hor_ticks = np.arange(np.ceil(hor_lo / hor_step) * hor_step, hor_hi, hor_step)
        self.update()

class Widget2D(QWidget):
    """Widget showing a pixmap with labelled horizontal and vertical axes.

    ``setPixmap(pixmap, hor_min, hor_max, ver_min, ver_max)`` installs the image
    and the axis ranges. The horizontal minimum is drawn at the right edge and
    the vertical minimum at the bottom edge. Horizontal ticks are every 0.02
    units. The vertical tick step is a 1/2/5 x power-of-ten value chosen from
    the vertical span.
    """
    def __init__(self, parent = None):
        QWidget.__init__(self, parent)
        self.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum)
        self.setPixmap(QPixmap(), 0., 1., 0., 1.)

    def paintEvent(self, event):
        #Margins leave room for tick labels: 50 px left, 5 px top, 15 px right, 20 px below
        rectangle = self.rect().adjusted(50, 5, -15, -20)
        #Minima map to the right / bottom edges, maxima to the left / top edges
        pix_hor_min, pix_hor_max = rectangle.right(), rectangle.left()
        pix_ver_min, pix_ver_max = rectangle.bottom(), rectangle.top()
        pix_hor_ticks = (pix_hor_max - pix_hor_min) * (self.pnt_hor_ticks - self.pnt_hor_min) / (self.pnt_hor_max - self.pnt_hor_min) + pix_hor_min
        pix_hor_ticks = pix_hor_ticks.round().astype(np.int64).tolist()
        pix_ver_ticks = (pix_ver_max - pix_ver_min) * (self.pnt_ver_ticks - self.pnt_ver_min) / (self.pnt_ver_max - self.pnt_ver_min) + pix_ver_min
        pix_ver_ticks = pix_ver_ticks.round().astype(np.int64).tolist()
        painter = QPainter(self)
        try:
            for pix_hor_tick, pnt_hor_tick in zip(pix_hor_ticks, self.pnt_hor_ticks):
                #5 px tick below the image with its label centred underneath
                painter.drawLine(pix_hor_tick, pix_ver_min, pix_hor_tick, pix_ver_min + 5)
                text = '%.2f' % (pnt_hor_tick,)
                rect = painter.boundingRect(self.rect(), Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop, text)
                rect.moveCenter(QPoint(pix_hor_tick, pix_ver_min))
                rect.moveTop(pix_ver_min + 5)
                painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, text)
            for pix_ver_tick, pnt_ver_tick in zip(pix_ver_ticks, self.pnt_ver_ticks):
                #5 px tick left of the image with its label right-aligned next to it
                painter.drawLine(pix_hor_max, pix_ver_tick, pix_hor_max - 5, pix_ver_tick)
                text = ('%.' + self.pnt_ver_digit + 'f') % (pnt_ver_tick,)
                rect = painter.boundingRect(self.rect(), Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop, text)
                rect.moveCenter(QPoint(pix_hor_max, pix_ver_tick))
                rect.moveRight(pix_hor_max - 7)
                painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, text)
            painter.drawPixmap(rectangle, self.pixmap)
        finally:
            #Release the paint device explicitly instead of relying on garbage collection
            painter.end()

    def sizeHint(self):
        return QSize(321, 281)

    def setPixmap(self, pixmap, pnt_hor_min, pnt_hor_max, pnt_ver_min, pnt_ver_max):
        """Install ``pixmap`` and the axis limits, rebuild the ticks, and repaint."""
        self.pixmap = pixmap
        self.pnt_hor_min, self.pnt_hor_max = pnt_hor_min, pnt_hor_max
        self.pnt_ver_min, self.pnt_ver_max = pnt_ver_min, pnt_ver_max
        #Choose the step from the absolute span: a reversed or empty range would
        #otherwise give log10 of a non-positive number and crash in int() below
        pnt_ver_span = abs(self.pnt_ver_max - self.pnt_ver_min)
        if pnt_ver_span > 0. and np.isfinite(pnt_ver_span):
            pnt_ver_step = np.power(10., np.floor(np.log10(pnt_ver_span)))
            #Aim for roughly 5-10 ticks: refine the power-of-ten step by 5 or 2
            if pnt_ver_span / pnt_ver_step < 2.:
                pnt_ver_step /= 5.
            elif pnt_ver_span / pnt_ver_step < 5.:
                pnt_ver_step /= 2.
        else:
            pnt_ver_step = 1.
        #Enough decimals to distinguish neighbouring vertical ticks
        self.pnt_ver_digit = str(max(0, -int(np.floor(np.log10(pnt_ver_step)))))
        #Sorting the limits keeps the ticks when they arrive as min > max
        hor_lo, hor_hi = sorted((pnt_hor_min, pnt_hor_max))
        ver_lo, ver_hi = sorted((pnt_ver_min, pnt_ver_max))
        self.pnt_hor_ticks = np.arange(np.ceil(hor_lo / 0.02) * 0.02, hor_hi, 0.02)
        self.pnt_ver_ticks = np.arange(np.ceil(ver_lo / pnt_ver_step) * pnt_ver_step, ver_hi, pnt_ver_step)
        self.update()

class Picker(QWidget):
    """Main window of the Radon peak picker.

    The top plot shows the Fourier spectrum of the first series, or the Radon
    row at the speed marker. The bottom plot shows the Radon spectrum with the
    detected peaks. A control panel holds the scaling factors, visible ranges,
    marker positions and the peak table.

    Pixel state (all integers):
      * ``frq_pixel_beg``/``frq_pixel_end`` -- visible frequency window, as
        column indices into the fft-shifted spectra; always ``frq_pixels``
        wide and kept inside ``[0, tim_pixels]``.
      * ``spd_pixel_beg``/``spd_pixel_end`` -- visible speed window
        (``spd_pixels`` rows); changing it requires recomputing ``self.radon``.
      * ``frq_pixel_marker``/``spd_pixel_marker`` -- cursor position relative
        to the visible window, kept inside it.

    Keyboard controls:
      * ``-`` / ``=`` scale ``divide`` down / up; with Shift or Ctrl they
        shift ``subtract`` instead.
      * Ctrl+arrows move the view by 1 pixel; Shift+arrows move it by 16.
      * Plain arrows move the marker.
      * Space toggles the top plot between Fourier and Radon.
    """

    def __init__(self):
        super().__init__()

        #Fourier spectrum of every series, fft-shifted along frequency
        self.fourier = tf.transpose(tf.signal.fft(tf.transpose(fid)))
        self.fourier = tf.signal.fftshift(self.fourier, 0)
        self.fourier = tf.math.real(self.fourier)

        #Initial view: speeds centred on zero, frequencies centred in the spectrum
        self.spd_pixel_beg = -(spd_pixels // 2)
        self.spd_pixel_end = self.spd_pixel_beg + spd_pixels
        self.frq_pixel_beg = (tim_pixels - frq_pixels) // 2
        self.frq_pixel_end = self.frq_pixel_beg + frq_pixels

        #Ranges of time and series
        tim_param_beg, tim_param_end = 0., 1.
        ser_param_beg, ser_param_end = 0., 1.

        #Coordinates in time and series (the speed grid is built by _compute_radon)
        self.tim_param_grid = tf.range(tim_param_beg, tim_param_end, (tim_param_end - tim_param_beg) / tim_pixels)
        self.ser_param_grid = tf.range(ser_param_beg, ser_param_end, (ser_param_end - ser_param_beg) / ser_pixels)

        #Radon spectrum of the initial speed window
        self._compute_radon()

        self.divide = complex(fid[0, 0]).real
        self.subtract = 0.

        #Markers start in the middle of the visible window
        self.frq_pixel_marker = frq_pixels // 2
        self.spd_pixel_marker = spd_pixels // 2

        #True: the top plot shows the Fourier spectrum; False: the Radon row at the speed marker
        self.spectrum = True

        self.colormap = plt.get_cmap()

        self.subtractEdit = QLineEdit()
        self.divideEdit = QLineEdit()
        self.subtractEdit.setFixedWidth(96)
        self.divideEdit.setFixedWidth(96)
        self.subtractEdit.editingFinished.connect(self.onSubtractEdited)
        self.divideEdit.editingFinished.connect(self.onDivideEdited)

        self.frequencyFullMinEdit = QLineEdit('%.6f' % (frq_unit_beg,))
        self.frequencyFullMaxEdit = QLineEdit('%.6f' % (frq_unit_end,))
        self.frequencyFullMinEdit.setFixedWidth(96)
        self.frequencyFullMaxEdit.setFixedWidth(96)
        self.frequencyFullMinEdit.setEnabled(False)
        self.frequencyFullMaxEdit.setEnabled(False)

        self.frequencyVisibleMinEdit = QLineEdit()
        self.frequencyVisibleMaxEdit = QLineEdit()
        self.frequencyVisibleMinEdit.setFixedWidth(96)
        self.frequencyVisibleMaxEdit.setFixedWidth(96)
        self.frequencyVisibleMinEdit.editingFinished.connect(self.onFrequencyVisibleMinEdited)
        self.frequencyVisibleMaxEdit.editingFinished.connect(self.onFrequencyVisibleMaxEdited)

        self.speedVisibleMinEdit = QLineEdit()
        self.speedVisibleMaxEdit = QLineEdit()
        self.speedVisibleMinEdit.setFixedWidth(96)
        self.speedVisibleMaxEdit.setFixedWidth(96)
        self.speedVisibleMinEdit.editingFinished.connect(self.onSpeedVisibleMinEdited)
        self.speedVisibleMaxEdit.editingFinished.connect(self.onSpeedVisibleMaxEdited)

        self.frequencyMarkerEdit = QLineEdit()
        self.speedMarkerEdit = QLineEdit()
        self.frequencyMarkerEdit.setFixedWidth(96)
        self.speedMarkerEdit.setFixedWidth(96)
        self.frequencyMarkerEdit.editingFinished.connect(self.onFrequencyMarkerEdited)
        self.speedMarkerEdit.editingFinished.connect(self.onSpeedMarkerEdited)

        self.frequencyResolutionEdit = QLineEdit('%.6f' % (frq_unit_delta,))
        self.speedResolutionEdit = QLineEdit('%.6f' % (spd_unit_delta,))
        self.frequencyResolutionEdit.setFixedWidth(96)
        self.speedResolutionEdit.setFixedWidth(96)
        self.frequencyResolutionEdit.setEnabled(False)
        self.speedResolutionEdit.setEnabled(False)

        self.peakModel = PeakModel()

        self.peakView = QTableView()
        self.peakView.setModel(self.peakModel)
        self.peakView.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch)
        self.peakView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
        self.peakView.setFocusPolicy(Qt.FocusPolicy.NoFocus)
        self.peakView.pressed.connect(self.onPeakSelected)

        controlLayout = QGridLayout()

        controlLayout.addWidget(QLabel('Scaling factors'), 0, 1, 1, 5)
        controlLayout.addWidget(QLabel('Sub'), 1, 0)
        controlLayout.addWidget(self.subtractEdit, 1, 1)
        controlLayout.addWidget(QLabel('Div'), 1, 3)
        controlLayout.addWidget(self.divideEdit, 1, 4)

        controlLayout.addWidget(QLabel('Full frequency range'), 2, 1, 1, 5)
        controlLayout.addWidget(QLabel('Lft'), 3, 0)
        controlLayout.addWidget(self.frequencyFullMinEdit, 3, 1)
        controlLayout.addWidget(QLabel('ppm'), 3, 2)
        controlLayout.addWidget(QLabel('Rgh'), 3, 3)
        controlLayout.addWidget(self.frequencyFullMaxEdit, 3, 4)
        controlLayout.addWidget(QLabel('ppm'), 3, 5)

        controlLayout.addWidget(QLabel('Visible frequency range'), 4, 1, 1, 5)
        controlLayout.addWidget(QLabel('Lft'), 5, 0)
        controlLayout.addWidget(self.frequencyVisibleMinEdit, 5, 1)
        controlLayout.addWidget(QLabel('ppm'), 5, 2)
        controlLayout.addWidget(QLabel('Rgh'), 5, 3)
        controlLayout.addWidget(self.frequencyVisibleMaxEdit, 5, 4)
        controlLayout.addWidget(QLabel('ppm'), 5, 5)

        controlLayout.addWidget(QLabel('Visible speed range'), 6, 1, 1, 5)
        controlLayout.addWidget(QLabel('Top'), 7, 0)
        controlLayout.addWidget(self.speedVisibleMinEdit, 7, 1)
        controlLayout.addWidget(QLabel('ppb/' + par_unit_name), 7, 2)
        controlLayout.addWidget(QLabel('Bot'), 7, 3)
        controlLayout.addWidget(self.speedVisibleMaxEdit, 7, 4)
        controlLayout.addWidget(QLabel('ppb/' + par_unit_name), 7, 5)

        controlLayout.addWidget(QLabel('Marker positions'), 8, 1, 1, 5)
        controlLayout.addWidget(QLabel('Frq'), 9, 0)
        controlLayout.addWidget(self.frequencyMarkerEdit, 9, 1)
        controlLayout.addWidget(QLabel('ppm'), 9, 2)
        controlLayout.addWidget(QLabel('Spd'), 9, 3)
        controlLayout.addWidget(self.speedMarkerEdit, 9, 4)
        controlLayout.addWidget(QLabel('ppb/' + par_unit_name), 9, 5)

        controlLayout.addWidget(QLabel('Frequency and speed per pixel'), 10, 1, 1, 5)
        controlLayout.addWidget(QLabel('Frq'), 11, 0)
        controlLayout.addWidget(self.frequencyResolutionEdit, 11, 1)
        controlLayout.addWidget(QLabel('ppm'), 11, 2)
        controlLayout.addWidget(QLabel('Spd'), 11, 3)
        controlLayout.addWidget(self.speedResolutionEdit, 11, 4)
        controlLayout.addWidget(QLabel('ppb/' + par_unit_name), 11, 5)

        controlLayout.addWidget(self.peakView, 12, 0, 1, 6)

        controlLayout.addWidget(QLabel('Center for Machine Learning @ University of Warsaw'), 13, 0, 1, 6)

        self.fou_widget = Widget1D()
        self.rad_widget = Widget2D()

        spectrumLayout = QVBoxLayout()
        spectrumLayout.addWidget(self.fou_widget)
        spectrumLayout.addWidget(self.rad_widget)

        mainLayout = QHBoxLayout()
        mainLayout.addLayout(controlLayout)
        mainLayout.addLayout(spectrumLayout, 1)

        self.setLayout(mainLayout)
        self.setWindowTitle('Radon Peak Picker 1.0')

        self.detect()
        self.draw()
        self.setFocus()

    def _compute_radon(self):
        """Recompute ``self.spd_param_grid`` and ``self.radon`` for the current speed window.

        For every speed on the grid, each series FID is multiplied by the phase
        ``exp(-2*pi*i * speed * time * series)``. The series are then averaged,
        and the result is Fourier transformed and fft-shifted along time. Only
        the real part is kept.
        """
        #Build the grid from integer pixel indices so it always has exactly
        #spd_pixels entries (a float-stepped range can gain or lose one)
        spd_pixel_grid = tf.range(self.spd_pixel_beg, self.spd_pixel_beg + spd_pixels)
        self.spd_param_grid = tf.cast(spd_pixel_grid, tf.float32) * spd_param_delta
        phase = tf.exp(-2. * math.pi * tf.complex(0., self.spd_param_grid[:, None, None] * self.tim_param_grid[:, None] * self.ser_param_grid))
        radon = phase * fid
        radon = tf.reduce_mean(radon, 2)
        radon = tf.signal.fft(radon)
        radon = tf.signal.fftshift(radon, 1)
        self.radon = tf.math.real(radon)

    def _shift_frequency(self, pixels):
        """Move the visible frequency window by ``pixels`` columns (clamped later)."""
        self.frq_pixel_beg += pixels
        self.frq_pixel_end += pixels

    def _shift_speed(self, pixels):
        """Move the visible speed window by ``pixels`` rows and recompute the Radon spectrum."""
        self.spd_pixel_beg += pixels
        self.spd_pixel_end += pixels
        self._compute_radon()

    def _clamp_view(self):
        """Keep the frequency window inside the spectrum and both markers inside the window.

        Without this, the slices in detect/draw would wrap around (negative
        start) or come up short, and marker indexing would raise IndexError.
        """
        self.frq_pixel_beg = int(min(max(self.frq_pixel_beg, 0), tim_pixels - frq_pixels))
        self.frq_pixel_end = self.frq_pixel_beg + frq_pixels
        self.frq_pixel_marker = int(min(max(self.frq_pixel_marker, 0), frq_pixels - 1))
        self.spd_pixel_marker = int(min(max(self.spd_pixel_marker, 0), spd_pixels - 1))

    def _read_float(self, edit):
        """Return the finite float typed into ``edit``, or None if the text is not one.

        Exceptions escaping a PyQt slot abort the application, so bad input is
        reported as None instead of raising.
        """
        try:
            value = float(edit.text())
        except ValueError:
            return None
        return value if math.isfinite(value) else None

    def _read_param(self, edit, param_from_unit, param_delta):
        """Convert the value typed into ``edit`` to a (fractional) pixel coordinate, or None."""
        value = self._read_float(edit)
        if value is None:
            return None
        param = param_from_unit(value) / param_delta
        return param if math.isfinite(param) else None

    def onSubtractEdited(self):
        value = self._read_float(self.subtractEdit)
        if value is not None:
            self.subtract = value
            self.detect()
        #Redraw even on bad input so the edit box shows the current value again
        self.draw()
        self.setFocus()

    def onDivideEdited(self):
        value = self._read_float(self.divideEdit)
        #A zero divisor would turn the whole normalised spectrum into inf/nan
        if value is not None and value != 0.:
            self.divide = value
            self.detect()
        self.draw()
        self.setFocus()

    def onFrequencyVisibleMinEdited(self):
        param = self._read_param(self.frequencyVisibleMinEdit, frq_param_from_unit, frq_param_delta)
        if param is not None:
            self.frq_pixel_beg = int(round(param))
            self.frq_pixel_end = self.frq_pixel_beg + frq_pixels
            self.detect()
        self.draw()
        self.setFocus()

    def onFrequencyVisibleMaxEdited(self):
        param = self._read_param(self.frequencyVisibleMaxEdit, frq_param_from_unit, frq_param_delta)
        if param is not None:
            self.frq_pixel_end = int(round(param))
            self.frq_pixel_beg = self.frq_pixel_end - frq_pixels
            self.detect()
        self.draw()
        self.setFocus()

    def onSpeedVisibleMinEdited(self):
        param = self._read_param(self.speedVisibleMinEdit, spd_param_from_unit, spd_param_delta)
        if param is not None:
            self.spd_pixel_beg = int(round(param))
            self.spd_pixel_end = self.spd_pixel_beg + spd_pixels
            self._compute_radon()
            self.detect()
        self.draw()
        self.setFocus()

    def onSpeedVisibleMaxEdited(self):
        param = self._read_param(self.speedVisibleMaxEdit, spd_param_from_unit, spd_param_delta)
        if param is not None:
            self.spd_pixel_end = int(round(param))
            self.spd_pixel_beg = self.spd_pixel_end - spd_pixels
            self._compute_radon()
            self.detect()
        self.draw()
        self.setFocus()

    def onFrequencyMarkerEdited(self):
        param = self._read_param(self.frequencyMarkerEdit, frq_param_from_unit, frq_param_delta)
        if param is not None:
            #Markers sit at pixel centres, hence the 0.5 offset
            self.frq_pixel_marker = int(round(param - self.frq_pixel_beg - 0.5))
        self.draw()
        self.setFocus()

    def onSpeedMarkerEdited(self):
        param = self._read_param(self.speedMarkerEdit, spd_param_from_unit, spd_param_delta)
        if param is not None:
            self.spd_pixel_marker = int(round(param - self.spd_pixel_beg - 0.5))
        self.draw()
        self.setFocus()

    def onPeakSelected(self, index):
        """Move both markers onto the most probable pixel of the clicked peak."""
        #Table row r is blob r + 1, whose argmax pixel detect() already stored
        row = index.row()
        self.spd_pixel_marker = int(self.spd_pixel[row])
        self.frq_pixel_marker = int(self.frq_pixel[row])
        self.draw()

    def detect(self):
        """Run the detector on the visible window and rebuild the peak table.

        Sets the following attributes:
          * ``obj_logit``/``obj_prob``/``obj_label`` -- per pixel of the
            visible window.
          * ``obj_blob``/``obj_blobs`` -- 8-connected components of the
            positive pixels.
          * ``spd_pixel``/``frq_pixel`` -- the most probable pixel of every
            blob.
        """
        self._clamp_view()
        radon = self.radon[:, self.frq_pixel_beg: self.frq_pixel_end]
        radon = (radon - self.subtract) / self.divide
        #Predicted logits from model for a batch of one image
        self.obj_logit = model.predict(radon[None])[0]
        #Non-negative logits are detections; probabilities via the logistic function
        self.obj_label = np.where(self.obj_logit < 0., False, True)
        self.obj_prob = scipy.special.expit(self.obj_logit)
        #Label the transpose so blob numbers follow frequency first
        self.obj_blob, self.obj_blobs = scipy.ndimage.label(self.obj_label.T, np.ones((3, 3)))
        self.obj_blob = self.obj_blob.T
        #Per-blob probability maps (zero outside the blob) and their argmax pixels
        obj_hot = 1 + np.arange(self.obj_blobs)[:, None, None] == self.obj_blob
        obj_prob = np.where(obj_hot, self.obj_prob, 0.)
        obj_index = obj_prob.reshape(-1, spd_pixels * frq_pixels).argmax(1)
        self.spd_pixel, self.frq_pixel = np.unravel_index(obj_index, (spd_pixels, frq_pixels))
        peakList = [(frq_unit_from_param((self.frq_pixel_beg + frq_pixel + 0.5) * frq_param_delta),
                     spd_unit_from_param((self.spd_pixel_beg + spd_pixel + 0.5) * spd_param_delta),
                     obj_prob[index, spd_pixel, frq_pixel])
                    for index, (spd_pixel, frq_pixel) in enumerate(zip(self.spd_pixel, self.frq_pixel))]
        self.peakModel.setPeakList(peakList)

    def draw(self):
        """Refresh every edit box, the peak-table selection and both plots from the current state."""
        self._clamp_view()

        self.subtractEdit.setText('%.4e' % (self.subtract,))
        self.divideEdit.setText('%.4e' % (self.divide,))

        self.frequencyVisibleMinEdit.setText('%.6f' % (frq_unit_from_param(self.frq_pixel_beg * frq_param_delta),))
        self.frequencyVisibleMaxEdit.setText('%.6f' % (frq_unit_from_param(self.frq_pixel_end * frq_param_delta),))

        self.speedVisibleMinEdit.setText('%.6f' % (spd_unit_from_param(self.spd_pixel_beg * spd_param_delta),))
        self.speedVisibleMaxEdit.setText('%.6f' % (spd_unit_from_param(self.spd_pixel_end * spd_param_delta),))

        self.frequencyMarkerEdit.setText('%.6f' % (frq_unit_from_param((self.frq_pixel_beg + self.frq_pixel_marker + 0.5) * frq_param_delta),))
        self.speedMarkerEdit.setText('%.6f' % (spd_unit_from_param((self.spd_pixel_beg + self.spd_pixel_marker + 0.5) * spd_param_delta),))

        #Highlight the table row of the blob under the marker (labels start at 1)
        obj_blob = self.obj_blob[self.spd_pixel_marker, self.frq_pixel_marker]
        if obj_blob:
            self.peakView.selectRow(int(obj_blob) - 1)
        else:
            self.peakView.clearSelection()

        rad_array = self.radon[:, self.frq_pixel_beg: self.frq_pixel_end].numpy()
        rad_array = (rad_array - self.subtract) / self.divide / thf_param_end

        rad_plot = self.colormap(rad_array)
        rad_plot = rad_plot[:, :, : 3]

        #Plot markers in blue
        rad_plot[self.spd_pixel_marker, :] = [0., 0., 1.]
        rad_plot[:, self.frq_pixel_marker] = [0., 0., 1.]

        #Plot detected peaks in red and white
        rad_plot[self.obj_label] = [1., 0., 0.]
        rad_plot[self.spd_pixel, self.frq_pixel] = [1., 1., 1.]

        rad_plot = rad_plot * 255.
        rad_plot = rad_plot.astype(np.uint8)

        rad_image = QImage(rad_plot.data, frq_pixels, spd_pixels, frq_pixels * 3, QImage.Format.Format_RGB888)
        rad_pixmap = QPixmap.fromImage(rad_image)
        self.rad_widget.setPixmap(rad_pixmap, frq_unit_from_param(self.frq_pixel_end * frq_param_delta),
                                              frq_unit_from_param(self.frq_pixel_beg * frq_param_delta),
                                              spd_unit_from_param(self.spd_pixel_end * spd_param_delta),
                                              spd_unit_from_param(self.spd_pixel_beg * spd_param_delta))

        if self.spectrum:
            thf_value = self.fourier[self.frq_pixel_beg: self.frq_pixel_end, 0].numpy()
        else:
            thf_value = self.radon[self.spd_pixel_marker, self.frq_pixel_beg: self.frq_pixel_end].numpy()
        thf_param = (thf_value - self.subtract) / self.divide

        #The bar plot spans values [-1, thf_param_end) with 256 pixels per thf_param_end
        thf_end_pixel = 256
        thf_size_frac = (1. + thf_param_end) / thf_param_end
        thf_size_pixel = int(thf_end_pixel * thf_size_frac)

        thf_grid_param = np.linspace(-1., thf_param_end, thf_size_pixel, endpoint = False)

        thf_plot = np.ones((frq_pixels, thf_size_pixel, 3)) * [0.267004, 0.004874, 0.329415]
        if self.spectrum:
            thf_plot[thf_grid_param < thf_param[:, None]] = [1.0, 0.75, 0.0]
        else:
            thf_plot[thf_grid_param < thf_param[:, None]] = [0.993248, 0.906157, 0.143936]

        if nsf_param_end:
            #Noise band around 1 (the noise range is scaled by sqrt(ser_pixels) for the spectrum)
            nsf_size_param = nsf_param_end * (np.sqrt(ser_pixels) if self.spectrum else 1.)
            nsf_bot_frac = (1. - nsf_size_param) / thf_param_end
            nsf_top_frac = (1. + nsf_size_param) / thf_param_end
            nsf_bot_pixel = int(thf_end_pixel * nsf_bot_frac)
            nsf_top_pixel = int(thf_end_pixel * nsf_top_frac)
            thf_plot[:, nsf_bot_pixel: nsf_top_pixel] /= 2.
            thf_plot[:, nsf_bot_pixel: nsf_top_pixel, 2] = 1.
        else:
            zer_frac = 1. / thf_param_end
            zer_pixel = int(thf_end_pixel * zer_frac)
            thf_plot[:, zer_pixel] /= 2.
            thf_plot[:, zer_pixel, 2] = 1.

        #Plot frequency marker in blue
        thf_plot[self.frq_pixel_marker, :, :] /= 2.
        thf_plot[self.frq_pixel_marker, :, 2] = 1.

        #Values grow upwards in the image; copy() makes the buffer contiguous for QImage
        thf_plot = np.flip(thf_plot, 1).transpose((1, 0, 2))
        thf_plot = thf_plot * 255.
        thf_plot = thf_plot.astype(np.uint8).copy()

        thf_image = QImage(thf_plot.data, frq_pixels, thf_size_pixel, frq_pixels * 3, QImage.Format.Format_RGB888)
        thf_pixmap = QPixmap.fromImage(thf_image)
        self.fou_widget.setPixmap(thf_pixmap, frq_unit_from_param(self.frq_pixel_end * frq_param_delta),
                                              frq_unit_from_param(self.frq_pixel_beg * frq_param_delta))

    def keyPressEvent(self, event):
        """Handle the keyboard controls listed in the class docstring; pass other keys to Qt."""
        key = event.key()
        minus = key == Qt.Key.Key_Minus or key == Qt.Key.Key_Underscore
        plus = key == Qt.Key.Key_Plus or key == Qt.Key.Key_Equal
        left = key == Qt.Key.Key_Left
        right = key == Qt.Key.Key_Right
        up = key == Qt.Key.Key_Up
        down = key == Qt.Key.Key_Down
        space = key == Qt.Key.Key_Space
        modifiers = event.modifiers()
        shift = bool(modifiers & Qt.KeyboardModifier.ShiftModifier)
        control = bool(modifiers & Qt.KeyboardModifier.ControlModifier)
        plain = not (shift or control)
        #The divide factor steps by 2**(1/4): four presses scale it by 2
        if minus and plain:
            self.divide /= 1.189207115002721
            self.detect()
        elif plus and plain:
            self.divide *= 1.189207115002721
            self.detect()
        elif minus:
            self.subtract -= self.divide / 16.
            self.detect()
        elif plus:
            self.subtract += self.divide / 16.
            self.detect()
        elif (left or right) and not plain:
            #Ctrl moves by one pixel; Shift (without Ctrl) moves by 16
            step = 1 if control else 16
            self._shift_frequency(step if right else -step)
            self.detect()
        elif (up or down) and not plain:
            step = 1 if control else 16
            self._shift_speed(step if down else -step)
            self.detect()
        elif left:
            self.frq_pixel_marker -= 1
        elif right:
            self.frq_pixel_marker += 1
        elif up:
            self.spd_pixel_marker -= 1
        elif down:
            self.spd_pixel_marker += 1
        elif space:
            self.spectrum = not self.spectrum
        else:
            super().keyPressEvent(event)
            return
        self.draw()

#Build and show the main window, then run the Qt event loop. The loop's exit
#status becomes the process exit status.
picker = Picker()
picker.show()

sys.exit(app.exec())
