import os
import subprocess
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse

import argparse
import torch
from torch import nn
from collections import OrderedDict
import time

import numpy as np
from contextlib import nullcontext

from models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from models.preprocess import AugmentMelSTFT
from helpers.utils import NAME_TO_WIDTH

import ffmpeg
import mimetypes

import ssl

# Directory that requested audio file names are resolved against (the request
# handler prepends it to the file name taken from the URL).
# NOTE: the first assignment is immediately overridden by the second one, so
# only the songcache directory is actually used.
prefix = "/run/media/tristan/Main/music/tidal-dl/Playlist/my_tracklist_10k/"
prefix = "/home/tristan/dev/tridonn/EfficientAT/songcache/"

# Placeholders that setup_models() overwrites: `device` becomes a
# torch.device and `audio_path` is taken from the command-line arguments.
device = ""
audio_path = ""

# Tagging network and mel-spectrogram front-end; both stay None until
# setup_models() has run.
model = None
mel = None

# Sample rate in Hz that tag_audio_window() uses to convert window and hop
# durations from seconds to samples.
sample_rate = 32000

def setup_models(args):
    """Load the tagging model and the mel front-end into module globals.

    Args:
        args: parsed command-line namespace. The attributes read are
            ``model_name``, ``ensemble``, ``strides``, ``head_type``,
            ``audio_path``, ``sample_rate``, ``window_size``, ``hop_size``
            and ``n_mels``.

    Side effects:
        Rebinds the module-level ``device``, ``audio_path``, ``sample_rate``,
        ``model`` and ``mel``. Both networks are switched to eval mode.
    """
    # `sample_rate` has to be listed here: without it the assignment below
    # only creates a local, and the module-level value that
    # tag_audio_window() uses to convert seconds to samples would never
    # follow --sample_rate.
    global device, audio_path, sample_rate, model, mel

    # Inference is pinned to the CPU; no accelerator backend is selected.
    device = torch.device('cpu')
    audio_path = args.audio_path
    sample_rate = args.sample_rate

    # Load the pre-trained network: an ensemble when one was requested,
    # otherwise the single model named by --model_name.
    if args.ensemble:
        model = get_ensemble_model(args.ensemble)
    else:
        model = get_mobilenet(
            width_mult=NAME_TO_WIDTH(args.model_name),
            pretrained_name=args.model_name,
            strides=args.strides,
            head_type=args.head_type,
        )
    model.eval()

    # Front-end that turns raw waveforms into mel spectrograms for the model.
    mel = AugmentMelSTFT(
        n_mels=args.n_mels,
        sr=sample_rate,
        win_length=args.window_size,
        hopsize=args.hop_size,
    )
    mel.eval()


def tag_audio_window(audio_path, window_size=10.0, hop_length=5.0):
    """Run the tagging model over sliding windows of an audio file.

    The file is decoded by ffmpeg to mono float32 PCM at the module-level
    ``sample_rate``. It is then cut into windows of ``window_size`` seconds
    taken every ``hop_length`` seconds, with the last window zero-padded to
    full length. Each window is passed through the module-level ``mel``
    front-end and ``model``.

    Args:
        audio_path (str): path to the audio file.
        window_size (float): size of each analysis window in seconds.
        hop_length (float): distance between window starts in seconds.

    Returns:
        numpy.ndarray: 1-D float32 array holding the sigmoid-activated
        predictions of every window, concatenated in window order. Empty
        when the file decodes to no samples.

    Raises:
        ValueError: if the window or the hop is shorter than one sample.
    """
    window_samples = int(window_size * sample_rate)
    hop_samples = int(hop_length * sample_rate)
    if window_samples <= 0 or hop_samples <= 0:
        raise ValueError("window_size and hop_length must be at least one sample long")

    # Decode to raw little-endian float32 on stdout so no temp file is needed.
    raw_audio, _ = (
        ffmpeg.input(audio_path)
        .output('-', format='f32le', acodec='pcm_f32le', ac=1, ar=str(sample_rate))
        .run(capture_stdout=True, capture_stderr=True)
    )
    print(f"GOT to {audio_path}")

    # frombuffer() yields a read-only view of the bytes; copy so torch gets a
    # writable array.
    samples = np.frombuffer(raw_audio, dtype=np.float32).copy()
    if samples.size == 0:
        return np.array([], dtype=np.float32)
    waveform = torch.from_numpy(samples[None, :]).to(device)

    # Always analyse at least one window: a clip shorter than a window is
    # padded up to one instead of being silently dropped.
    total_samples = waveform.shape[1]
    n_windows = max(1, int(np.ceil((total_samples - window_samples) / hop_samples)) + 1)
    padded_length = (n_windows - 1) * hop_samples + window_samples
    waveform = torch.nn.functional.pad(waveform, (0, padded_length - total_samples))

    # Mixed precision is only enabled on CUDA; elsewhere this is a no-op.
    if device.type == 'cuda':
        amp_context = torch.autocast(device_type=device.type)
    else:
        amp_context = nullcontext()

    # Collect per-window results and join once at the end instead of
    # re-allocating the whole result array on every iteration.
    window_predictions = []
    with torch.no_grad(), amp_context:
        for i in range(n_windows):
            start = i * hop_samples
            spec = mel(waveform[:, start:start + window_samples])
            preds, _features = model(spec.unsqueeze(0))
            probabilities = torch.sigmoid(preds.float()).cpu().numpy()
            window_predictions.append(probabilities.reshape(-1))

            # progress bar
            print(f'\rProgress: {i+1}/{n_windows}', end='')

    res = np.concatenate(window_predictions)
    print(res)
    return res

class MyHandler(BaseHTTPRequestHandler):
    """HTTP GET handler for the music front-end.

    The URL path's suffix selects one of three kinds of response:

    * ``/<mode>/<file>.m4a`` -- ``<file>`` (URL-decoded) is looked up under
      the module-level ``prefix`` directory and ``<mode>`` picks the result:

      - ``preview``: WAV of the first ``clip_seconds`` seconds,
      - ``full``: WAV of everything after the first ``clip_seconds`` seconds,
      - ``info``: ffprobe's JSON description of the file,
      - ``tags``: raw bytes of the array returned by
        ``tag_audio_window(file, 20, 2.5)``.

    * ``*.txt`` -- the file at that path relative to the current working
      directory.
    * anything else -- the file at that path under ``static_root``; ``/``
      serves ``index_file`` from the current working directory.

    Paths that escape their base directory, unknown modes and unreadable
    files get a 404. A failing ffmpeg/ffprobe/tagging run gets a 500. All
    fallible work happens before the status line is sent, so a failure never
    produces a half-written 200 response.
    """

    # Directory holding the front-end's static files.
    static_root = "/run/media/tristan/Main/projects/flymusic/flymusic_5_2/"
    # Page served for "/", resolved against the current working directory.
    index_file = "flymusic_5_2.html"
    # Where ffmpeg writes its temporary WAV output.
    scratch_dir = "/dev/shm"
    # Preview length / full-track start offset in seconds, as passed to ffmpeg.
    clip_seconds = "20"

    _AUDIO_MODES = ("preview", "full", "info", "tags")
    _CORS_HEADERS = (("Access-Control-Allow-Origin", "*"),)
    _TEXT_HEADERS = _CORS_HEADERS + (
        ("Cross-Origin-Embedder-Policy", "require-corp"),
    )
    _STATIC_HEADERS = _TEXT_HEADERS + (
        ("Cross-Origin-Opener-Policy", "same-origin"),
    )

    def do_GET(self):
        """Dispatch a GET request on the suffix of its URL path."""
        print(self.path)

        url_path = urllib.parse.urlparse(self.path).path

        if url_path.endswith('.m4a'):
            self._serve_audio(url_path)
        elif url_path.endswith('.txt'):
            self._serve_file(os.getcwd(), url_path[1:], self._TEXT_HEADERS)
        elif url_path == "/":
            self._serve_file(os.getcwd(), self.index_file, self._STATIC_HEADERS)
        else:
            self._serve_file(self.static_root, url_path[1:], self._STATIC_HEADERS)

    @staticmethod
    def _resolve_under(root, relative):
        """Return the absolute path of ``relative`` inside ``root``.

        Returns None when the normalised path would leave ``root`` (for
        example via ``..``). Leading slashes in ``relative`` are dropped so
        it can never be read as an absolute path. Symlinks are not resolved.
        """
        base = os.path.abspath(root)
        candidate = os.path.abspath(os.path.join(base, relative.lstrip("/")))
        if os.path.commonpath([base, candidate]) != base:
            return None
        return candidate

    def _send_body(self, body, content_type, extra_headers=()):
        """Send a complete 200 response carrying ``body`` (bytes)."""
        self.send_response(200)
        self.send_header('Content-type', content_type)
        self.send_header('Content-Length', str(len(body)))
        for name, value in extra_headers:
            self.send_header(name, value)
        self.end_headers()
        self.wfile.write(body)

    def _serve_file(self, root, relative, extra_headers):
        """Send the file ``relative`` under ``root``, or a 404."""
        target = self._resolve_under(root, relative)
        print(target)
        if target is None:
            self.send_error(404, 'File not found')
            return

        try:
            with open(target, 'rb') as f:
                body = f.read()
        except (OSError, ValueError):
            # Missing file, directory, permission problem or an embedded NUL.
            self.send_error(404, 'File not found')
            return

        mimetype, _ = mimetypes.guess_type(target)
        # guess_type() returns None for unknown extensions; never emit that
        # as a literal "None" header value.
        self._send_body(body, mimetype or 'application/octet-stream', extra_headers)

    def _serve_audio(self, url_path):
        """Handle ``/<mode>/<file>.m4a`` requests."""
        # Only the first path segment is the mode. Stripping "full/" etc.
        # anywhere in the name would corrupt names such as "beautiful/x.m4a".
        mode, _, filename = urllib.parse.unquote(url_path[1:]).partition("/")
        if mode not in self._AUDIO_MODES:
            self.send_error(404, 'File not found')
            return
        print("filename: " + filename)

        source = self._resolve_under(prefix, filename)
        if source is None or not os.path.isfile(source):
            self.send_error(404, 'File not found')
            return

        try:
            if mode == "info":
                probe = subprocess.run(
                    ['ffprobe', '-v', 'quiet', '-print_format', 'json',
                     '-show_format', '-show_streams', source],
                    check=True, capture_output=True)
                body = probe.stdout
                content_type = 'application/json'
            elif mode == "tags":
                # tobytes() yields the same raw C-order bytes that tofile()
                # would write, without a round trip through a temp file.
                body = tag_audio_window(source, 20, 2.5).tobytes()
                content_type = 'application/octet-stream'
            else:
                body = self._transcode_clip(source, preview=(mode == "preview"))
                content_type = 'audio/x-wav'
        except Exception as exc:
            # Request boundary: report any tool or model failure to the
            # client instead of dropping the connection mid-response.
            self.log_error("%s failed for %r: %r", mode, source, exc)
            self.send_error(500, 'Audio processing failed')
            return

        self._send_body(body, content_type, self._CORS_HEADERS)

    def _transcode_clip(self, source, preview):
        """Return WAV bytes for ``source`` produced by ffmpeg.

        With ``preview`` true the output is limited to the first
        ``clip_seconds`` seconds (``-t``). Otherwise ffmpeg is told to start
        ``clip_seconds`` seconds in (``-ss``).
        """
        import tempfile

        # A unique temp name avoids collisions between requests; the file is
        # removed even when ffmpeg fails.
        fd, wav_path = tempfile.mkstemp(suffix='.wav', dir=self.scratch_dir)
        os.close(fd)
        try:
            time_flag = '-t' if preview else '-ss'
            subprocess.run(
                ['ffmpeg', '-i', source, time_flag, self.clip_seconds, '-y', wav_path],
                check=True)
            with open(wav_path, 'rb') as f:
                return f.read()
        finally:
            os.remove(wav_path)

def run(server_class=HTTPServer, handler_class=MyHandler):
    """Parse the command line, load the models and serve HTTPS forever.

    Args:
        server_class: server type to instantiate; called as
            ``server_class((host, port), handler_class)``.
        handler_class: request handler class passed to the server.

    The listening socket is bound and wrapped in TLS before the models are
    loaded, so a busy port or a missing certificate fails fast. The socket
    is closed again when serving stops for any reason.
    """
    parser = argparse.ArgumentParser(description='Example of parser. ')
    # model name decides, which pre-trained model is loaded
    parser.add_argument('--model_name', type=str, default='mn20_as')
    parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int)
    parser.add_argument('--head_type', type=str, default="mlp")
    parser.add_argument('--intel', action='store_true', default=False)
    parser.add_argument('--audio_path', type=str, required=False)

    # preprocessing (window_size and hop_size are passed to the mel
    # front-end as win_length and hopsize)
    parser.add_argument('--sample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=800)
    parser.add_argument('--hop_size', type=int, default=320)
    parser.add_argument('--n_mels', type=int, default=128)

    # overwrite 'model_name' by 'ensemble_model' to evaluate an ensemble
    parser.add_argument('--ensemble', nargs='+', default=[])

    # server settings; the defaults reproduce the previously hard-coded values
    parser.add_argument('--port', type=int, default=10454)
    parser.add_argument('--certfile', type=str, default="/dev/shm/cube.crt")
    parser.add_argument('--keyfile', type=str, default="/dev/shm/cube.key")

    args = parser.parse_args()

    httpd = server_class(('', args.port), handler_class)

    # ssl.wrap_socket() was removed in Python 3.12; build an explicit
    # server-side context instead.
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_context.load_cert_chain(certfile=args.certfile, keyfile=args.keyfile)
    httpd.socket = tls_context.wrap_socket(httpd.socket, server_side=True)

    print('Starting httpd...')
    setup_models(args)
    try:
        httpd.serve_forever()
    finally:
        httpd.server_close()

# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    run()
