from modules.scene.train_ert import ERTTrainer
from panda3d.core import GraphicsOutput, Texture, FrameBufferProperties, WindowProperties, ClockObject, Mat4, Vec3, Quat
from panda3d.core import Geom, GeomNode, GeomVertexFormat, GeomTriangles, GeomVertexData, GeomVertexWriter, GeomVertexReader, GeomPoints, NodePath
from panda3d.core import BillboardEffect, LineSegs, TextNode, LColor, Vec4
from panda3d.core import GraphicsPipe, Point3, PNMImage, TransparencyAttrib, LMatrix4f, CardMaker, NodePath, Shader, DepthWriteAttrib, Camera
from direct.gui.DirectGui import DirectWaitBar
from direct.gui.OnscreenImage import OnscreenImage
from direct.showbase.ShowBase import ShowBase
import numpy as np
from direct.task import Task
import torch
import math
import wx
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.scripts.exporter import ExportGaussianSplat
from pathlib import Path
import os
import time
import open3d as o3d
import cv2
import gc
import multiprocessing
import ctypes
import threading
# Base camera translation speed; SplatShowbase.move_camera multiplies it by the
# frame delta time, and set_speed rescales it (left shift requests a 3x factor).
SPEED = 0.25
# Resolution of the splat render, the shared image/depth buffers and the
# overlay texture (half of 1280x720).
WIDTH = 1280//2
HEIGHT = 720//2
# NOTE(review): not referenced in this part of the file — confirm still used.
FULL_WIDTH = 1920
FULL_HEIGHT = 1080
# Lens field of view in degrees; the pixel focal length is derived from it
# over WIDTH, i.e. it is treated as the horizontal FOV.
FOV = 75

class RenderERT(multiprocessing.Process):
    """Worker process that renders ERT/splat frames into shared buffers.

    The worker loads the reconstruction at ``recon_path`` with nerfstudio's
    ``eval_setup`` and then serves requests from ``render_queue``. Each request
    is ``(camera_to_worlds, focal_length, want_depth)``; a ``None`` request
    stops the loop. After writing a BGRA uint8 image (and, when requested, a
    float32 depth map) into the shared buffers, it puts ``"image copied"`` on
    ``response_queue``.
    """

    def __init__(self, recon_path : str, rgb_buffer, depth_buffer, render_queue : multiprocessing.Queue, response_queue : multiprocessing.Queue):
        """Store the reconstruction path, shared buffers and queues.

        Args:
            recon_path: path passed to ``eval_setup`` in the worker.
            rgb_buffer: writable byte buffer receiving the BGRA image.
            depth_buffer: writable byte buffer receiving the float32 depth map.
            render_queue: incoming render requests (``None`` = stop).
            response_queue: outgoing completion notifications.
        """
        super().__init__()
        self.render_queue = render_queue
        self.response_queue = response_queue
        self.recon_path = recon_path
        self.rgb_buffer = rgb_buffer
        self.depth_buffer = depth_buffer

    @staticmethod
    def _to_bgra(rgb, alpha=210):
        """Convert a float RGB image in [0, 1] to a uint8 BGRA image.

        Values are scaled by 255, clipped to [0, 255] and truncated to uint8.
        The channel order is reversed, and a constant ``alpha`` channel is
        appended.
        """
        rgb_uint8 = (rgb * 255).clip(0, 255).astype(np.uint8)
        alpha_channel = np.full(rgb_uint8.shape[:2], alpha, dtype=np.uint8)
        return np.dstack((rgb_uint8[..., ::-1], alpha_channel))

    @staticmethod
    def _copy_into(buffer, array):
        """Copy the raw bytes of ``array`` into the writable byte ``buffer``.

        Uses a numpy view of the buffer so the copy is one vectorized native
        operation. Slice-assigning ``bytes`` to a ctypes array loops
        element by element in Python.
        Raises ValueError if the byte sizes differ.
        """
        src = np.ascontiguousarray(array).reshape(-1).view(np.uint8)
        np.frombuffer(buffer, dtype=np.uint8)[:] = src

    def run(self):
        """Load the reconstruction, then serve render requests until ``None``."""
        _config, pipeline, _, _ = eval_setup(
            Path(self.recon_path),
            test_mode="inference",
        )
        ert_model = pipeline.model

        while True:
            data = self.render_queue.get()
            if data is None:
                break
            # Skip this request while an earlier response is still waiting to
            # be consumed by the viewer.
            if not self.response_queue.empty():
                continue

            camera_to_worlds, focal_length, want_depth = data[0], data[1], data[2]
            splat_camera = Cameras(camera_to_worlds,
                                   focal_length,
                                   focal_length,
                                   float(WIDTH // 2),
                                   float(HEIGHT // 2),
                                   width=WIDTH,
                                   height=HEIGHT)
            with torch.no_grad():
                ert_imgs = ert_model.get_outputs_for_camera(splat_camera)

            rgb = ert_imgs["rgb"].cpu().numpy()
            self._copy_into(self.rgb_buffer, self._to_bgra(rgb))

            if want_depth:
                depth = ert_imgs["depth"].cpu().numpy()
                # The viewer reads this buffer as float32, so write float32.
                self._copy_into(self.depth_buffer, np.asarray(depth, dtype=np.float32))

            self.response_queue.put("image copied")


class SplatShowbase(ShowBase):
    """Panda3D viewer, embedded in a wx panel, for ERT/splat reconstructions.

    Frames are rendered by a ``RenderERT`` worker process into shared memory
    and uploaded to a full-screen overlay texture. The viewer also provides:

    - keyboard/mouse fly-camera controls;
    - a reconstruction progress bar;
    - a two-point distance measurement tool;
    - PLY export with a user-defined scale.
    """

    # Number of extra frames rendered after the camera stops moving; these
    # follow-up frames also request a depth map.
    _RENDER_SETTLE_FRAMES = 3

    def __init__(self, panel, app, iter_count):
        """Open the Panda3D window inside ``panel`` and set up the scene.

        Args:
            panel: wx panel hosting the Panda3D window. Its native handle is
                the parent window, and its EVT_SIZE drives resizing.
            app: wx application whose pending events ``wxTask`` pumps.
            iter_count: ``iter_count * 30`` is the step count at which the
                loading bar is full.
        """
        self.app = app
        self.wxApp = app
        self.wx_panel = panel
        wp = WindowProperties()
        wp.setOrigin(0, 0)
        wp.setSize(WIDTH, HEIGHT)
        wp.setParentWindow(self.wx_panel.GetHandle())

        ShowBase.__init__(self, fStartDirect=True, windowType='none')
        self.pointclouds = []
        self.overlayCard = None
        self.disableMouse()

        base.openMainWindow(props=wp)
        self.wx_panel.Bind(wx.EVT_SIZE, self.OnResize)
        self.win.setClearColor((0.0, 0.0, 0.0, 1.0))

        self.camLens.setFov(FOV)
        self.camLens.setAspectRatio(WIDTH / HEIGHT)
        self.camLens.setNear(0.01)
        # Last camera matrix seen by update_render_card; used to detect movement.
        self.cam_mat = np.array(base.cam.get_mat())
        self.init_controls()

        self.taskMgr.add(self.wxTask, "wxTask")
        self.taskMgr.add(self.move_camera, "move_camera_task")

        self.recon_loaded = False
        self.line_np = None
        self.text_np = None
        self.measure_ctr = 0
        self.sphere_a = self._make_marker()
        self.sphere_b = self._make_marker()

        self.scale = 1
        self.measure = True

        self.recon_step_max = (iter_count * 30)
        self.loading_bar = DirectWaitBar(text="Recon in Progress", range=100, value=0, pos=(0, 0, 0), scale=1.5)
        self.loading_bar['barColor'] = (0.13, 0.59, 0.95, 1)
        self.loading_bar.hide()

        # Overlay texture that receives the worker's WIDTH x HEIGHT RGBA8 frames.
        self.texture = Texture()
        self.texture.setup2dTexture(WIDTH, HEIGHT, Texture.TUnsignedByte, Texture.FRgba8)

        cm = CardMaker('overlay')
        cm.setFrameFullscreenQuad()
        self.overlayCard = NodePath(cm.generate())
        self.overlayCard.reparentTo(self.render2d)
        self.overlayCard.setTexture(self.texture)
        self.overlayCard.setTransparency(TransparencyAttrib.MAlpha)
        self.overlayCard.setBin('fixed', 60)

        base.setFrameRateMeter(True)
        print(base.win.gsg.driver_vendor)
        print(base.win.gsg.driver_renderer)
        clock = ClockObject.getGlobalClock()
        clock.setMode(ClockObject.MLimited)
        clock.setFrameRate(30)

    def _make_marker(self):
        """Load a small green sphere under ``render``, hidden until placed."""
        marker = self.loader.loadModel("models/misc/sphere")
        marker.setColor(0, 1, 0, 1)
        marker.setScale(0.03)
        marker.reparentTo(self.render)
        marker.hide()
        return marker

    def recon_in_progress(self):
        """Reset the loading bar to 0 and show it."""
        self.loading_bar['value'] = 0
        self.loading_bar.show()

    def update_loading_bar(self, val=1):
        """Set the loading bar to ``val`` steps out of ``recon_step_max``.

        The bar hides itself and resets to 0 once it reaches its range.
        """
        self.loading_bar['value'] = int((val / self.recon_step_max) * 100)
        if self.loading_bar['value'] >= self.loading_bar['range']:
            self.loading_bar.hide()
            self.loading_bar['value'] = 0

    def init_controls(self):
        """Reset movement state and bind keyboard/mouse events."""
        self.speed = SPEED * 1.0
        self.last_mouse_pos = (0, 0)
        self.control_map = {'forward': 0,
                            'backward': 0,
                            'left': 0,
                            'right': 0,
                            'drag': 0,
                            'roll-left': 0,
                            'roll-right': 0,
                            'up': 0,
                            'down': 0}

        # Each key sets its action to 1 on press and back to 0 on release.
        key_actions = {
            "w": "forward", "arrow_up": "forward",
            "s": "backward", "arrow_down": "backward",
            "a": "left", "arrow_left": "left",
            "d": "right", "arrow_right": "right",
            "q": "up", "e": "down",
            "z": "roll-left", "c": "roll-right",
        }
        for key, action in key_actions.items():
            self.accept(key, self.set_control_map, [action, 1])
            self.accept(key + "-up", self.set_control_map, [action, 0])

        # NOTE(review): Panda3D names the left shift button "lshift"; confirm a
        # "leftshift" event is ever thrown. There is also no release handler
        # that restores the normal speed.
        self.accept("leftshift", self.set_speed, [3])

        self.accept('mouse1', self.set_control_map, ["drag", 1])
        self.accept('mouse1-up', self.set_control_map, ["drag", 0])

        self.accept("r", self.clear_measure)
        self.accept('space', self.on_space_click)

    def _stop_worker(self):
        """Terminate the render worker and wait briefly to reap it."""
        self.render_process.terminate()
        # Joining avoids leaving a zombie process behind after terminate().
        self.render_process.join(timeout=5)

    def unload(self):
        """Stop rendering, shut down the worker and clear scene state."""
        self.taskMgr.remove("move_camera_task")
        self.taskMgr.remove("render_task")
        if hasattr(self, 'render_process'):
            self.render_queue.put(None)
            self._stop_worker()
        for pc in self.pointclouds:
            pc.removeNode()
        self.pointclouds.clear()
        self.recon_loaded = False
        self.clear_measure()
        # Blank the overlay (HEIGHT rows x WIDTH columns x 4 channels).
        self.texture.setRamImage(np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8))

    def load_recon(self, ert):
        """Start a render worker for the reconstruction at path ``ert``.

        Any loaded reconstruction is unloaded first. If ``ert`` is ``None``,
        the method only records that and returns.
        """
        if self.recon_loaded:
            self.unload()

        self.loading_bar.hide()
        self.loading_bar['value'] = 0

        recon_path = ert
        self.recon_path = recon_path
        if recon_path is None:
            return

        self.render_buffer = self._RENDER_SETTLE_FRAMES
        self.render_queue = multiprocessing.Queue()
        self.response_queue = multiprocessing.Queue()
        # 4 bytes per pixel: BGRA uint8 image, and float32 depth map.
        self.shared_rgb = multiprocessing.RawArray(ctypes.c_ubyte, WIDTH * HEIGHT * 4)
        self.shared_depth = multiprocessing.RawArray(ctypes.c_ubyte, WIDTH * HEIGHT * 4)

        self.rgb = np.frombuffer(self.shared_rgb, dtype=np.uint8)
        self.depth = np.frombuffer(self.shared_depth, dtype=np.float32)

        self.render_process = RenderERT(recon_path, self.shared_rgb, self.shared_depth, self.render_queue, self.response_queue)
        self.render_process.start()

        self.recon_loaded = True

        self.render_splat()
        # __init__ already registered a move_camera task. Remove it (and any
        # stale render task) so exactly one of each runs per frame;
        # otherwise camera movement is doubled.
        self.taskMgr.remove("move_camera_task")
        self.taskMgr.remove("render_task")
        self.taskMgr.add(self.move_camera, "move_camera_task")
        self.taskMgr.add(self.update_render_card, "render_task")
        self.render_splat()

    def get_export_data(self):
        """Return ``(scale, recon_path)``."""
        return (self.scale, self.recon_path)

    def export_ply(self, fp):
        """Export the splat to the PLY file ``fp``, then scale it by ``self.scale``."""
        out_dir = os.path.dirname(fp)
        file_name = os.path.basename(fp)
        pcd_exporter = ExportGaussianSplat(Path(self.recon_path), Path(out_dir), file_name)
        pcd_exporter.main()

        pcd = o3d.io.read_point_cloud(fp)
        scaling_mat = np.diag([self.scale, self.scale, self.scale, 1])
        pcd.transform(scaling_mat)
        o3d.io.write_point_cloud(fp, pcd)

    def wxTask(self, task):
        """Per-frame task that pumps pending wx events."""
        while self.app.HasPendingEvents():
            self.app.ProcessPendingEvents()
        self.app.Yield()
        return Task.cont

    def reset_cam(self):
        """Reset the camera transform to identity and request a render."""
        base.cam.setMat(Mat4.identMat())
        self.render_splat()

    def set_speed(self, value):
        """Set movement speed to ``SPEED * value``."""
        self.speed = SPEED * value

    def show_cursor(self, show):
        """Request a cursor-hidden state of ``show`` for the window.

        NOTE(review): ``show`` is passed straight to setCursorHidden, so
        show=True hides the cursor. Confirm what callers expect before
        changing this.
        """
        wp = WindowProperties()
        wp.setCursorHidden(show)
        self.win.requestProperties(wp)

    def set_control_map(self, key, value):
        """Set movement flag ``key`` to ``value`` (ignored until a recon is loaded)."""
        if not self.recon_loaded:
            return

        if key == 'drag' and value and self.mouseWatcherNode.hasMouse():
            self.last_mouse_pos = (self.mouseWatcherNode.getMouseX(), self.mouseWatcherNode.getMouseY())
        self.control_map[key] = value

    def update_render_card(self, task):
        """Per-frame task that refreshes the overlay when the camera changes.

        After the camera stops moving, ``_RENDER_SETTLE_FRAMES`` more frames
        are rendered, with depth requested. Each render blocks on the worker's
        response while holding the graphics engine's render lock.
        """
        mat = np.array(base.cam.get_mat())
        depth = False

        if not np.array_equal(self.cam_mat, mat):
            self.render_buffer = self._RENDER_SETTLE_FRAMES
            self.cam_mat = mat
        elif self.render_buffer > 0:
            # A few renders after movement stops; without them not all of the
            # user's movement ends up rendered.
            self.render_buffer -= 1
            depth = True
        else:
            return Task.cont

        render_lock = base.graphicsEngine.render_lock
        render_lock.acquire()
        try:
            self.render_splat(depth)
            self.response_queue.get()
            self.texture.setRamImage(self.rgb)
        finally:
            # Always release, or an exception would stall rendering for good.
            render_lock.release()

        return Task.cont

    def move_camera(self, task):
        """Per-frame task: keyboard translation/roll and mouse-drag look."""
        dt = globalClock.getDt()
        cam_mat = base.cam.get_mat()
        right_vec = cam_mat.get_row3(0)
        forward_vec = cam_mat.get_row3(1)
        up_vec = cam_mat.get_row3(2)
        step = self.speed * dt

        new_pos = base.cam.getPos()
        if self.control_map['forward']:
            new_pos = new_pos + forward_vec * step
        if self.control_map['backward']:
            new_pos = new_pos - forward_vec * step
        if self.control_map['left']:
            new_pos = new_pos - right_vec * step
        if self.control_map['right']:
            new_pos = new_pos + right_vec * step
        if self.control_map['up']:
            new_pos = new_pos + up_vec * step
        if self.control_map['down']:
            new_pos = new_pos - up_vec * step

        # Roll is a fixed 1 degree per frame (not scaled by dt).
        if self.control_map['roll-left']:
            base.cam.set_hpr(base.cam, 0, 0, -1)
        if self.control_map['roll-right']:
            base.cam.set_hpr(base.cam, 0, 0, 1)

        base.cam.setPos(new_pos)

        # Mouse coordinates are only valid while the pointer is in the window.
        if self.control_map['drag'] and self.mouseWatcherNode.hasMouse():
            mouse_x = self.mouseWatcherNode.getMouseX()
            mouse_y = self.mouseWatcherNode.getMouseY()

            dx = (mouse_x - self.last_mouse_pos[0]) * -100
            dy = (mouse_y - self.last_mouse_pos[1]) * 100

            if dx != 0 or dy != 0:
                base.cam.setH(base.cam, dx)
                base.cam.setP(base.cam, dy)
                # NOTE(review): re-requests the current properties; purpose
                # unclear — kept as-is.
                wp = base.win.getProperties()
                self.win.requestProperties(wp)

            # Re-centre the pointer; the window centre is (0, 0) in normalized
            # mouse coordinates, hence the reset of last_mouse_pos.
            wp = base.win.getProperties()
            middle_x = int(wp.getXSize() / 2)
            middle_y = int(wp.getYSize() / 2)
            self.win.movePointer(0, middle_x, middle_y)
            self.last_mouse_pos = (0, 0)

        return Task.cont

    @staticmethod
    def _pixel_focal_length():
        """Focal length in pixels for a field of view of ``FOV`` degrees over ``WIDTH``."""
        return WIDTH / (math.tan(math.radians(FOV / 2)) * 2)

    def render_splat(self, depth=False):
        """Queue a render of the current view (no-op when nothing is loaded).

        Args:
            depth: also ask the worker to write the depth map.
        """
        if not self.recon_loaded:
            return
        _, c2w = self.get_cam_info()
        self.render_queue.put((c2w, self._pixel_focal_length(), depth))

    def on_space_click(self):
        """Place a measurement point at the depth under the mouse pointer.

        Clicks 1 and 2 place the end points and display the distance. Click 3
        clears the measurement.
        """
        if not self.recon_loaded:
            return

        # NOTE(review): the response to this request is not awaited, so the
        # depth read below may come from an earlier render — confirm.
        self.render_splat(depth=True)

        if not self.mouseWatcherNode.hasMouse():
            return

        mpos = self.mouseWatcherNode.getMouse()
        # Do not multiply y by -1 here due to image flipping in panda renderer.
        px = int(math.floor(((mpos.getX() + 1) / 2.0) * WIDTH))
        py = int(math.floor(((mpos.getY() + 1) / 2.0) * HEIGHT))
        # Mouse coordinates of exactly +1 would map one pixel past the edge.
        px = min(max(px, 0), WIDTH - 1)
        py = min(max(py, 0), HEIGHT - 1)
        depth = -self.depth[py * WIDTH + px]

        c2w = np.array(base.cam.get_mat())
        # NOTE(review): differs from get_cam_info's matrix in the [0][0] sign.
        rotation_matrix = np.array([[-1, 0, 0, 0],
                                    [0, 0, -1, 0],
                                    [0, -1, 0, 0],
                                    [0, 0, 0, 1]])
        c2w = np.transpose(np.dot(rotation_matrix, c2w))

        # Back-project the pixel through a pinhole model centred on the image.
        focal_length = self._pixel_focal_length()
        cx = WIDTH / 2
        cy = HEIGHT / 2
        x_cam = (px - cx) * depth / focal_length
        y_cam = (py - cy) * depth / focal_length
        z_cam = float(depth)

        P_cam = np.array([x_cam, y_cam, z_cam, 1])
        P_world = np.dot(c2w, P_cam)
        P_world /= P_world[3]

        self.measure_ctr += 1
        if self.measure_ctr == 3:
            self.clear_measure()
            return
        if self.measure_ctr == 1:
            self.point_a = Point3(P_world[0], P_world[1], P_world[2])
            self.sphere_a.setPos(self.point_a)
            self.sphere_a.show()
        elif self.measure_ctr == 2:
            self.point_b = Point3(P_world[0], P_world[1], P_world[2])
            self.sphere_b.setPos(self.point_b)
            self.sphere_b.show()
            self.calc_distance_and_display()

    def calc_distance_and_display(self):
        """Draw a line between the two points and label it with the scaled distance.

        Sets ``self.distance`` to the distance multiplied by ``self.scale``.
        """
        line = LineSegs()
        line.setColor(0, 0, 0, 1)
        line.setThickness(5)
        line.moveTo(self.point_a)
        line.drawTo(self.point_b)

        line_node = line.create()
        self.line_np = NodePath(line_node)
        self.line_np.reparentTo(self.render)

        distance = np.linalg.norm(np.array(self.point_b) - np.array(self.point_a)) * self.scale
        self.distance = distance
        midpoint = (self.point_a + self.point_b) / 2

        text_node = TextNode('distance_text')
        text_node.setText(f"{distance:.2f}")
        text_node.setTextColor(LColor(.35, 1, .35, 1))
        text_node.setAlign(TextNode.ACenter)
        text_node.setShadowColor(Vec4(1, 1, 1, 1))  # White shadow colour
        text_node.setShadow(0.04, 0.04)  # Offset for the shadow

        self.text_np = base.render.attachNewNode(text_node)
        self.text_np.setPos(midpoint + Vec3(0, 0, 0.01))

        self.text_np.setScale(0.05)
        billboard = BillboardEffect.make(Vec3(0, 0, 1), True, True, 0.0, base.cam, Point3(0, 0, 0))
        self.text_np.setEffect(billboard)
        # Draw the label on top of the scene regardless of depth.
        self.text_np.setBin('fixed', 100)
        self.text_np.setDepthTest(False)
        self.text_np.setDepthWrite(False)

    def clear_measure(self):
        """Reset the measurement: hide the markers and remove the line and label."""
        self.measure_ctr = 0
        self.sphere_a.hide()
        self.sphere_b.hide()
        if self.line_np:
            self.line_np.removeNode()
        self.line_np = None
        if self.text_np:
            self.text_np.removeNode()
        self.text_np = None

    def scale_model(self, scale_to):
        """Rescale so that the last measured distance equals ``scale_to``.

        Does nothing if no non-zero distance has been measured yet.
        """
        distance = getattr(self, 'distance', None)
        if not distance:
            return
        self.scale = self.scale * (scale_to / distance)
        self.recalc_distance()

    def recalc_distance(self):
        """Redraw the current measurement (if any) using the current scale."""
        if self.line_np and self.text_np:
            self.line_np.removeNode()
            self.line_np = None
            self.text_np.removeNode()
            self.text_np = None
            self.calc_distance_and_display()

    def get_wx_prop(self):
        """Return window properties fitting a 16:9 area centred in the wx panel."""
        wp = WindowProperties()
        wp.setOrigin(0, 0)
        w, h = self.wx_panel.GetSize()
        target_height = int(w * (9.0 / 16.0))
        target_width = w
        offsetX, offsetY = 0, 0

        if target_height > h:
            # Panel is too short: fit to height and letterbox horizontally.
            target_height = h
            target_width = int(target_height * (16.0 / 9.0))
            offsetX = (w - target_width) // 2
        else:
            offsetY = (h - target_height) // 2

        wp.setSize(target_width, target_height)
        wp.setOrigin(offsetX, offsetY)
        wp.setParentWindow(self.wx_panel.GetHandle())
        return wp

    def OnResize(self, event):
        """wx EVT_SIZE handler: resize the Panda3D window to fit the panel."""
        wp = self.get_wx_prop()
        base.win.requestProperties(wp)
        event.Skip()

    def get_cam_info(self):
        """Return ``(K, c2w)`` for the current Panda3D camera.

        K is a 3x3 intrinsics matrix built from the lens focal length and film
        size. c2w is the camera matrix, pre-multiplied by a fixed axis
        permutation/flip and transposed, as a float32 torch tensor of shape
        (1, 4, 4).
        """
        lens = base.camLens
        focal_length = lens.get_focal_length()
        film_size = lens.get_film_size()
        principal_point = (film_size[0] / 2, film_size[1] / 2)
        K = np.array([
            [focal_length, 0, principal_point[0]],
            [0, focal_length, principal_point[1]],
            [0, 0, 1]
        ])
        c2w = np.array(base.cam.get_mat())
        # Presumably converts Panda3D's camera axes to the renderer's
        # convention — TODO confirm.
        rotation_matrix = np.array([[1, 0, 0, 0],
                                    [0, 0, -1, 0],
                                    [0, -1, 0, 0],
                                    [0, 0, 0, 1]])

        c2w = np.transpose(np.dot(rotation_matrix, c2w))
        c2w_torch = torch.from_numpy(c2w).unsqueeze(0).float()

        return (K, c2w_torch)

    def exit(self):
        """Stop the render worker (if running) and exit the Panda3D app."""
        if self.recon_loaded:
            self._stop_worker()
            self.recon_loaded = False
        base.userExit()