from direct.showbase.ShowBase import ShowBase

import simplepbr
from direct.task import Task
from direct.showbase import ShowBaseGlobal
import wx
import trimesh
import numpy as np
import os
import math
from panda3d.core import WindowProperties, Filename, Texture, PNMImage, LineSegs, NodePath, TextNode, BillboardEffect, GeomVertexRewriter
from panda3d.core import AmbientLight, DirectionalLight,  BoundingBox, Material, LColor, Vec4, Vec3,Point3, ColorAttrib
from panda3d.core import GeomNode, GeomVertexFormat, GeomVertexReader, GeomVertexWriter, Geom,  GeomVertexData, GeomPoints
from panda3d.core import CollisionTraverser, CollisionNode, CollisionRay, CollisionHandlerQueue, CollisionPolygon, CollisionBox, BitMask32
from panda3d.core import loadPrcFileData

# Set the Panda3D config variable "load-file-type p3assimp" before ShowBase
# starts, so the assimp-based loader plugin is registered for loadModel().
loadPrcFileData("", "load-file-type p3assimp")
class QShowbase(ShowBase):
    """Panda3D viewer embedded in a wx panel, with point-to-point measuring.

    The Panda3D window is parented to the wx panel's native handle and
    resized with it. After a mesh is loaded with :meth:`insert_mesh`,
    pressing ``space`` picks a point on it: the first pick places point A,
    the second places point B and draws the labelled distance, and a third
    pick clears the measurement. Pressing ``r`` also clears it.
    """

    def __init__(self, panel, frame_main):
        """Open the Panda3D window inside ``panel`` and set up mouse picking.

        Args:
            panel: wx panel hosting the Panda3D window; its size and native
                handle define the window.
            frame_main: owning wx frame; stored as ``self.frame_main``.
        """
        self.wx_panel = panel
        self.frame_main = frame_main

        # The window is opened manually (windowType='none') so that it can be
        # parented to the wx panel; its properties only need the panel.
        wp = self.get_wx_prop()
        ShowBase.__init__(self, fStartDirect=True, windowType='none')

        self.startWx()
        self.openMainWindow(props=wp, gsg=None)
        self.wx_panel._win = self.win

        simplepbr.init(render_node=self.render, enable_shadows=False)
        self.setBackgroundColor(0.02, 0.02, 0.02)
        self.wx_panel.Bind(wx.EVT_SIZE, self.OnResize)

        # Line and label of the current measurement (None when not shown).
        self.line_np = None
        self.text_np = None

        # Mouse-ray picker. It only hits geometry whose collide mask includes
        # bit 1; on_mouse_click grants that bit to the mesh while picking.
        self.cTrav = CollisionTraverser()
        self.cHandler = CollisionHandlerQueue()
        self.pickerNode = CollisionNode('mouseRay')
        self.pickerNP = self.cam.attachNewNode(self.pickerNode)
        self.pickerNode.setFromCollideMask(BitMask32.bit(1))
        self.pickerRay = CollisionRay()
        self.pickerNode.addSolid(self.pickerRay)
        self.cTrav.addCollider(self.pickerNP, self.cHandler)
        self.accept('space', self.on_mouse_click)

        # Number of points picked so far in the current measurement.
        self.measure_ctr = 0
        self.sphere_a = self._make_marker((1, 0, 0, 1))  # red: first point
        self.sphere_b = self._make_marker((0, 0, 1, 1))  # blue: second point

        # Multiplier applied to displayed distances and to exported meshes.
        self.scale = 1
        # When False, on_mouse_click ignores picks.
        self.measure = True

        self.accept('r', self.clear_measure)

    def _make_marker(self, color):
        """Load a small hidden sphere under render to mark a picked point."""
        marker = self.loader.loadModel("models/misc/sphere")
        marker.setColor(*color)
        marker.setScale(0.05)
        marker.reparentTo(self.render)
        marker.hide()
        return marker

    def calc_distance_and_display(self):
        """Draw a line from ``point_a`` to ``point_b`` and label it with the distance.

        The Euclidean distance is multiplied by ``self.scale`` and stored in
        ``self.distance``. The created nodes are kept in ``self.line_np`` and
        ``self.text_np`` so they can be removed later.
        """
        line = LineSegs()
        line.setColor(0, 0, 0, 1)
        line.setThickness(5)
        line.moveTo(self.point_a)
        line.drawTo(self.point_b)
        self.line_np = NodePath(line.create())
        self.line_np.reparentTo(self.render)

        distance = np.linalg.norm(np.array(self.point_b) - np.array(self.point_a)) * self.scale
        self.distance = distance
        midpoint = (self.point_a + self.point_b) / 2

        text_node = TextNode('distance_text')
        text_node.setText(f"{distance:.2f}")
        text_node.setTextColor(LColor(.35, 1, .35, 1))
        text_node.setAlign(TextNode.ACenter)
        text_node.setShadowColor(Vec4(0, 0, 0, 1))  # black shadow
        text_node.setShadow(0.04, 0.04)  # shadow offset

        self.text_np = self.render.attachNewNode(text_node)
        self.text_np.setPos(midpoint + Vec3(0, 0, 0.5))
        self.text_np.setScale(0.5)
        # Keep the label facing the camera and draw it over the mesh
        # (fixed bin, no depth test/write).
        billboard = BillboardEffect.make(Vec3(0, 0, 1), True, True, 0.0, self.cam, Point3(0, 0, 0))
        self.text_np.setEffect(billboard)
        self.text_np.setBin('fixed', 100)
        self.text_np.setDepthTest(False)
        self.text_np.setDepthWrite(False)

    def print_vertex_info(self, model):
        """Print position, normal and color of every vertex under ``model`` (debug aid).

        Normal and color print as None when the vertex data lacks that column.
        """
        for geom_np in model.findAllMatches('**/+GeomNode'):
            geom_node = geom_np.node()
            for i in range(geom_node.getNumGeoms()):
                vdata = geom_node.getGeom(i).getVertexData()

                vertex_reader = GeomVertexReader(vdata, 'vertex')
                normal_reader = GeomVertexReader(vdata, 'normal')
                color_reader = GeomVertexReader(vdata, 'color')
                # Ask the reader whether its column exists; the reader object
                # itself is not a reliable presence test.
                has_normal = normal_reader.hasColumn()
                has_color = color_reader.hasColumn()

                while not vertex_reader.isAtEnd():
                    vertex = vertex_reader.getData3()
                    normal = normal_reader.getData3() if has_normal else None
                    color = color_reader.getData4() if has_color else None
                    print(f"Vertex: {vertex}, Normal: {normal}, Color: {color}")

    def recalc_distance(self):
        """Redraw the current measurement, e.g. after ``self.scale`` changed.

        Does nothing unless both the line and the label are currently shown.
        """
        if self.line_np is None or self.text_np is None:
            return
        self.line_np.removeNode()
        self.line_np = None
        self.text_np.removeNode()
        self.text_np = None
        self.calc_distance_and_display()

    def convert_ply_to_obj(self, ply_path, obj_path):
        """Convert a mesh file to OBJ and also write an all-gray copy.

        Args:
            ply_path: source mesh, loaded with ``trimesh.load``.
            obj_path: destination OBJ. The gray copy goes to the
                "_untextured" companion path of ``obj_path``.
        """
        mesh = trimesh.load(ply_path)
        # Presumably forces trimesh to compute/cache vertex normals before
        # export -- TODO confirm this is still needed.
        mesh.vertex_normals = mesh.vertex_normals
        self._write_obj(mesh, obj_path)

        gray_color = np.array([0.5, 0.5, 0.5, 1.0])  # RGBA
        mesh.visual.vertex_colors = np.tile(gray_color, (mesh.vertices.shape[0], 1))
        self._write_obj(mesh, self._untextured_path(obj_path))

    @staticmethod
    def _write_obj(mesh, path):
        """Export ``mesh`` (including normals) as OBJ text to ``path``."""
        obj = trimesh.exchange.obj.export_obj(mesh, include_normals=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(obj)

    @staticmethod
    def _untextured_path(obj_path):
        """Return the path of the gray "_untextured" OBJ belonging to ``obj_path``.

        Only the extension is replaced, so ".obj" appearing elsewhere in the
        path (e.g. in a directory name) is left untouched.
        """
        root, _ = os.path.splitext(obj_path)
        return root + "_untextured.obj"

    def clear_measure(self):
        """Discard the current measurement (bound to ``r``).

        Removes the mesh from the picker's mask, resets the pick counter,
        hides both markers, and removes the line and label. Safe to call
        before any mesh is loaded.
        """
        textured = getattr(self, 'textured', None)
        if textured is not None:
            textured.setCollideMask(BitMask32.bit(0))
        self.measure_ctr = 0
        self.sphere_a.hide()
        self.sphere_b.hide()
        if self.line_np is not None:
            self.line_np.removeNode()
            self.line_np = None
        if self.text_np is not None:
            self.text_np.removeNode()
            self.text_np = None

    def insert_mesh(self, mesh_path):
        """Load a mesh and prepare its textured, gray and point-cloud views.

        A non-OBJ file is first converted with ``convert_ply_to_obj`` to an
        OBJ with the same base name; the source file is never overwritten.
        Any previously inserted mesh is removed. The textured model is shown.
        The gray "_untextured" companion and the point cloud are loaded
        hidden. The camera is parented to the textured model, the trackball
        origin is set to the bounding-box center, and lighting is (re)added.
        """
        root, ext = os.path.splitext(mesh_path)
        if ext.lower() != ".obj":
            obj_path = root + ".obj"
            self.convert_ply_to_obj(mesh_path, obj_path)
            mesh_path = obj_path

        self._remove_loaded_mesh()

        self.textured = self.loader.loadModel(Filename.fromOsSpecific(mesh_path))
        self.textured.reparentTo(self.render)
        self.textured.setCollideMask(BitMask32.bit(1))
        self.textured.setTwoSided(True)

        # NOTE(review): when an .obj is passed in directly, its
        # "_untextured" companion must already exist or loading fails.
        untextured_path = self._untextured_path(mesh_path)
        self.untextured = self.loader.loadModel(Filename.fromOsSpecific(untextured_path))
        self.untextured.setTwoSided(True)
        self.untextured.reparentTo(self.render)
        self.untextured.hide()

        self.point_cloud = self.create_point_cloud(self.textured)
        self.point_cloud.reparentTo(self.render)
        self.point_cloud.setLightOff(1)
        self.point_cloud.hide()

        bbox = self.get_bounding_box(self.textured)
        self.center = (bbox.getMin() + bbox.getMax()) * 0.5
        self.model_size = bbox.getMax() - bbox.getMin()
        trackball = self.trackball.node()
        self.cam_pos_inital = trackball.get_pos()
        self.camera.reparentTo(self.textured)
        trackball.setForwardScale(.01)
        trackball.setOrigin(self.center)
        self.reset_cam()
        self.add_lighting()

    def _remove_loaded_mesh(self):
        """Remove the models created by a previous ``insert_mesh`` call, if any."""
        if getattr(self, 'textured', None) is None:
            return
        self.clear_measure()
        # The camera is parented to the old model; move it back under render
        # first so removing the model does not detach the camera with it.
        self.camera.reparentTo(self.render)
        for attr in ('textured', 'untextured', 'point_cloud'):
            model = getattr(self, attr, None)
            if model is not None:
                model.removeNode()

    def reset_cam(self):
        """Reset the mouse trackball's transform."""
        self.trackball.node().reset()

    def on_mouse_click(self):
        """Pick a point on the textured mesh under the mouse (bound to ``space``).

        The first pick sets point A, the second sets point B and draws the
        distance, and the third clears the measurement. Does nothing when
        measuring is disabled, no mesh is loaded, or the ray hits nothing.
        """
        textured = getattr(self, 'textured', None)
        if not self.measure or textured is None:
            return

        # Expose the mesh to the picker (bit 1) only for this traversal.
        textured.setCollideMask(BitMask32.bit(1))
        try:
            collision_point = self._pick_surface_point()
        finally:
            textured.setCollideMask(BitMask32.bit(0))
        if collision_point is None:
            return

        self.measure_ctr += 1
        if self.measure_ctr == 1:
            self.point_a = collision_point
            self.sphere_a.setPos(collision_point)
            self.sphere_a.show()
        elif self.measure_ctr == 2:
            self.point_b = collision_point
            self.sphere_b.setPos(collision_point)
            self.sphere_b.show()
            self.calc_distance_and_display()
        else:
            self.clear_measure()

    def _pick_surface_point(self):
        """Cast the picker ray through the mouse position.

        Returns:
            The first collision entry's surface point in render space after
            sorting the entries, or None if the mouse is outside the window
            or nothing was hit.
        """
        if not self.mouseWatcherNode.hasMouse():
            return None
        mpos = self.mouseWatcherNode.getMouse()
        self.pickerRay.setFromLens(self.camNode, mpos.getX(), mpos.getY())

        self.cHandler.clearEntries()
        self.cTrav.traverse(self.render)
        if self.cHandler.getNumEntries() == 0:
            return None
        self.cHandler.sortEntries()
        print(len(self.cHandler.getEntries()))
        collision_point = self.cHandler.getEntry(0).getSurfacePoint(self.render)
        print("Collision point:", collision_point)
        return collision_point

    def get_first_transform(self, node):
        """Return the transform of the first GeomNode under ``node``, or None if there is none."""
        for geom_np in node.find_all_matches('**/+GeomNode'):
            return geom_np.get_transform()
        return None

    def set_grey_material(self, node, material):
        """Apply ``material`` to every GeomNode under ``node``."""
        for geom_np in node.find_all_matches('**/+GeomNode'):
            geom_np.set_material(material)

    def clear_materials_and_texture(self, node):
        """Clear texture and material and disable shaders on every GeomNode under ``node``."""
        for geom_np in node.find_all_matches('**/+GeomNode'):
            geom_np.clear_texture()
            geom_np.clear_material()
            geom_np.set_shader_off()

    def toggle_view(self, view_mode: int):
        """Show exactly one representation of the mesh.

        Args:
            view_mode: 0 shows the textured model, 1 the gray untextured
                model, and any other value shows the point cloud.
        """
        self.textured.hide()
        self.point_cloud.hide()
        self.untextured.hide()

        if view_mode == 0:
            self.textured.show()
        elif view_mode == 1:
            self.untextured.show()
        else:
            self.point_cloud.show()

    def create_point_cloud(self, node):
        """Build a GeomPoints node from every vertex position under ``node``.

        Positions are copied as stored in each Geom's vertex data; node
        transforms are not applied. The result is attached to render.
        """
        vertex_data = GeomVertexData('point_cloud', GeomVertexFormat.get_v3(), Geom.UH_static)
        vertex_writer = GeomVertexWriter(vertex_data, 'vertex')

        for geom_np in node.find_all_matches('**/+GeomNode'):
            geom_node = geom_np.node()
            for i in range(geom_node.get_num_geoms()):
                vertex_reader = GeomVertexReader(geom_node.get_geom(i).get_vertex_data(), 'vertex')
                while not vertex_reader.is_at_end():
                    vertex_writer.add_data3(vertex_reader.get_data3())

        points_prim = GeomPoints(Geom.UH_static)
        points_prim.add_next_vertices(vertex_data.get_num_rows())
        points_prim.close_primitive()

        point_geom = Geom(vertex_data)
        point_geom.add_primitive(points_prim)

        point_geom_node = GeomNode('point_cloud')
        point_geom_node.add_geom(point_geom)
        return self.render.attach_new_node(point_geom_node)

    def add_lighting(self):
        """Light the scene with a dim ambient light and a directional light on the camera.

        Lights added by an earlier call are removed first, so repeated mesh
        loads do not stack lights.
        """
        for attr in ('ambient_light_node', 'directional_light_node'):
            old_light = getattr(self, attr, None)
            if old_light is not None:
                self.render.clearLight(old_light)
                old_light.removeNode()

        ambient_light = AmbientLight("ambientLight")
        ambient_light.setColor((0.2, 0.2, 0.2, 1))
        self.ambient_light_node = self.render.attachNewNode(ambient_light)
        self.render.setLight(self.ambient_light_node)

        directional_light = DirectionalLight("directionalLight")
        directional_light.setColor((0.8, 0.8, 0.8, 1))
        # Parented to the camera so the light direction follows the view.
        self.directional_light_node = self.cam.attachNewNode(directional_light)
        self.render.setLight(self.directional_light_node)

    def flip_normals(self, node):
        """Negate every vertex normal under ``node`` in place.

        Vertex data without a normal column is skipped.
        """
        for geom_np in node.find_all_matches('**/+GeomNode'):
            geom_node = geom_np.node()
            for i in range(geom_node.get_num_geoms()):
                vdata = geom_node.modify_geom(i).modify_vertex_data()
                if not vdata.has_column('normal'):
                    continue
                # A rewriter reads and writes the same column in one pass,
                # instead of a separate reader and writer on the same data.
                normal_rewriter = GeomVertexRewriter(vdata, 'normal')
                while not normal_rewriter.is_at_end():
                    normal = normal_rewriter.get_data3()
                    normal_rewriter.set_data3(-normal)

    def toggle_measure(self, toggle):
        """Enable or disable picking via ``on_mouse_click``."""
        self.measure = toggle

    def scale_model(self, scale_to):
        """Rescale so the currently measured distance reads ``scale_to``.

        Multiplies ``self.scale`` by ``scale_to / self.distance`` and redraws
        the measurement.

        Raises:
            ValueError: if no distance has been measured yet or it is zero.
        """
        distance = getattr(self, 'distance', None)
        if not distance:
            raise ValueError("Measure a non-zero distance before scaling the model")
        self.scale *= scale_to / distance
        self.recalc_distance()

    def export_model(self, ply_path, export_path):
        """Load ``ply_path``, scale it by ``self.scale`` and write it as PLY to ``export_path``."""
        mesh = trimesh.load(ply_path)
        mesh.apply_scale(self.scale)
        mesh.export(export_path, file_type='ply')

    def get_bounding_box(self, node):
        """Return a BoundingBox built from ``node``'s tight bounds.

        Raises:
            ValueError: if ``node`` has no geometry to bound.
        """
        bounds = node.get_tight_bounds()
        if bounds is None:
            raise ValueError("Node has no geometry; cannot compute bounds")
        min_point, max_point = bounds
        return BoundingBox(min_point, max_point)

    def get_wx_prop(self):
        """Return WindowProperties that fill the wx panel and are parented to it."""
        wp = WindowProperties()
        wp.setOrigin(0, 0)
        w, h = self.wx_panel.GetSize()
        wp.setSize(w, h)
        wp.setParentWindow(self.wx_panel.GetHandle())
        return wp

    def OnResize(self, event):
        """Resize the Panda3D window to the panel and let wx keep processing the event."""
        self.win.requestProperties(self.get_wx_prop())
        event.Skip()

    def exit(self):
        """Shut down the Panda3D application via ``userExit``."""
        self.userExit()