import multiprocessing
import sys
from multiprocessing import Lock, Queue
import socket
import struct
import librosa
import numpy as np
import torch
import os
from model.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
import esc_config as config
import time

import numpy as np
import matplotlib.pyplot as plt
import math as rm
import json

class TTcpSampleReciever:
    """TCP server that receives length-prefixed blocks of float64 samples.

    Wire format of one block, as decoded below:
    a 2-byte unsigned length (native byte order), giving the payload size in
    bytes, followed by that many bytes of native-order float64 values.
    Only one client connection is served at a time.
    """

    def __init__(self, listen_port =33377):
        """Bind and listen on 127.0.0.1:``listen_port``.

        The client is accepted lazily on the first RecieveSamples() call.
        """
        self.m_CliSocket = None
        self.m_Socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.m_Socket.bind(("127.0.0.1", listen_port))
        self.m_Socket.listen()
        self.m_HaveConn=False

    def _drop_client(self):
        """Close the current client socket so the next call accepts a new one."""
        if self.m_CliSocket is not None:
            self.m_CliSocket.close()
        self.m_CliSocket = None
        self.m_HaveConn = False

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the client socket.

        :raises ConnectionError: if the peer closes the connection first.
        Any socket error also drops the client before propagating.
        """
        buf = bytearray()
        while len(buf) < size:
            try:
                packet = self.m_CliSocket.recv(size - len(buf))
            except OSError:
                self._drop_client()
                raise
            if not packet:
                # The peer closed the connection before the block was complete.
                self._drop_client()
                raise ConnectionError("Socket connection closed unexpectedly")
            buf.extend(packet)
        return buf

    def RecieveSamples(self):
        """Block until one full sample block arrives and return it as floats.

        Accepts a client connection first if none is active.
        Trailing payload bytes that do not form a whole float64 are ignored.

        :returns: list of float samples (possibly empty).
        :raises ConnectionError: if the client disconnects mid-block.
        """
        if not self.m_HaveConn:
            conn, _addr = self.m_Socket.accept()
            self.m_CliSocket = conn
            self.m_HaveConn = True
            print("sample reciever accepted conn")

        # recv() may return fewer bytes than requested, even for the 2-byte
        # header, so every read goes through _recv_exact().
        (block_size,) = struct.unpack('H', self._recv_exact(2))
        sample_count = block_size // 8
        payload = self._recv_exact(block_size)
        return list(struct.unpack_from(f"{sample_count}d", payload))

class TDescSender:
    """TCP client that pushes classification results to a remote consumer.

    The constructor blocks until it connects to 127.0.0.1:``dest_port``.
    """

    def __init__(self, dest_port =33399, retry_delay=0.1):
        """Connect to the result consumer, retrying until it is reachable.

        :param dest_port: TCP port on 127.0.0.1 to connect to.
        :param retry_delay: seconds to wait between failed connection attempts.
        """
        while True:
            # After a failed connect() the socket state is unspecified, so
            # every attempt uses a fresh socket.
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect(("127.0.0.1", dest_port))
            except OSError:
                sock.close()
                time.sleep(retry_delay)
                continue
            break
        self.m_Socket = sock
        print("REMOTE PLOTTER CONNECT OK", dest_port)

    def SendDesc(self, x,y):
        """Send one result as the ASCII text ``"<x>;<y>"``.

        NOTE(review): no message terminator is written, so TCP may coalesce
        consecutive results on the receiving side. Confirm the consumer copes.
        """
        self.m_Socket.sendall("{0};{1}".format(x,y).encode())


class Audio_Classification:
    """Inference wrapper around an HTSAT Swin-Transformer loaded from a checkpoint."""

    def __init__(self, model_path, config):
        """Build the model from ``config`` and load weights from ``model_path``.

        :param model_path: path to a checkpoint holding a ``"state_dict"`` entry.
        :param config: object providing the ``htsat_*`` and ``classes_num``
            hyper-parameters passed to the model constructor below.
        """
        super().__init__()

        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config=config,
            depths=config.htsat_depth,
            embed_dim=config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head
        )
        ckpt = torch.load(model_path, map_location="cpu")
        # Strip the first 10 characters of every key. This is presumably a
        # wrapper prefix such as "sed_model." added during training.
        # TODO confirm against the checkpoint.
        temp_ckpt = {key[10:]: value for key, value in ckpt["state_dict"].items()}
        self.sed_model.load_state_dict(temp_ckpt)
        self.sed_model.eval()

    def predict2(self, waveform):
        """Classify one waveform.

        :param waveform: numpy array of samples, converted to float32. Assumed
            1-D (a batch axis is added below) -- confirm with callers.
        :returns: ``(pred_label, drone_prob)``, where:
            - ``pred_label`` is the argmax index of the model's
              ``clipwise_output``;
            - ``drone_prob`` is the softmax weight of class index 0 over those
              outputs. Class 0 is presumably "drone"; confirm against the label
              map.
        """
        with torch.no_grad():
            x = torch.from_numpy(waveform).float()
            output_dict = self.sed_model(x[None, :], None, True)
            pred_post = output_dict['clipwise_output'][0].detach().cpu().numpy()

        pred_label = np.argmax(pred_post)
        # Softmax. Subtracting the max leaves the result unchanged but prevents
        # exp() overflow (math.pow(e, x) raised OverflowError for large x).
        scores = np.exp(pred_post.astype(np.float64) - np.max(pred_post))
        drone_prob = float(scores[0] / scores.sum())

        return pred_label, drone_prob



# class TSaver:
#     def __init__(self):
#         for s in samples:
#             iv = int(s)
#             f_out.write(iv.to_bytes(2,'little',signed=True))
#         f_out.flush()
#         pass

def read_pcm(fname):
    """Read a raw PCM file of signed 16-bit samples as floats.

    Samples are decoded in native byte order and scaled by 1/32768, giving
    values in [-1.0, 1.0). A trailing odd byte (an incomplete sample) is
    ignored.

    :param fname: path to the raw PCM file.
    :returns: list of float samples.
    """
    with open(fname, "rb") as f_in:
        raw = f_in.read()
    whole = len(raw) - (len(raw) % 2)  # drop an incomplete trailing sample
    return [float(s) / 32768 for (s,) in struct.iter_unpack("h", raw[:whole])]



def start_nn_service(samples_listen_port, target_sample_count:int, model_path , desc_port):
    """Run one classification worker forever.

    The worker accumulates samples received on ``samples_listen_port`` until
    at least ``target_sample_count`` are buffered. It then peak-normalises the
    buffer, classifies it, and sends ``(label, prob)`` to ``desc_port``.

    :param samples_listen_port: local TCP port to receive samples on.
    :param target_sample_count: minimum number of samples per classification.
    :param model_path: checkpoint path passed to Audio_Classification.
    :param desc_port: local TCP port of the result consumer.
    """
    config.hop_size = 320
    Audiocls = Audio_Classification(model_path, config)

    print("CKPT LOADED OK")
    signal_reciever = TTcpSampleReciever(samples_listen_port)
    desc_sender = TDescSender(desc_port)

    sample_buff = []
    while True:
        sample_buff.extend(signal_reciever.RecieveSamples())
        if len(sample_buff) < target_sample_count:
            continue

        waveform = np.asarray(sample_buff, dtype=np.float64)
        sample_buff = []

        # Peak-normalise to [-1, 1]. A silent (all-zero) window has peak 0,
        # which previously raised ZeroDivisionError and killed the worker.
        # Leave such a window unscaled.
        max_abs = float(np.max(np.abs(waveform))) if waveform.size else 0.0
        if max_abs > 0.0:
            waveform /= max_abs

        pred_label, pred_prob = Audiocls.predict2(waveform)
        desc_sender.SendDesc(pred_label, pred_prob)
        print(pred_label, pred_prob)



if __name__ == "__main__":
    # The guard is required for multiprocessing: under the "spawn" start
    # method (the default on Windows and macOS) every child re-imports this
    # module. Unguarded top-level code would then try to launch the workers
    # again.
    with open("nn_settings.json") as settings_file:
        js_settings = json.load(settings_file)
    proc_lst =[]

    #ID = int( sys.argv[1])
    #print("ID=",ID)

    sample_count_for_desc = int(js_settings["sample_count_for_desc"])
    cur_id =0
    for s in js_settings["nn_processors"]:
        #if cur_id!=ID:
        #    cur_id += 1
        #    continue
        ckpt_path=s["ckpt_path"]
        samples_recv_port = int(s["samples_recv_port"])
        descs_dest_port = int(s["descs_dest_port"])
        p = multiprocessing.Process(target=start_nn_service, args=(samples_recv_port, sample_count_for_desc, ckpt_path, descs_dest_port))

        p.start()
        print("process started")
        proc_lst.append(p)
        cur_id+=1

    #desc_proc = multiprocessing.Process(target=make_common_descision, args=(qlist, ))
    #desc_proc.start()
    for p in proc_lst:
        p.join()

    #desc_proc.join()