import wx
from wx.lib.agw import customtreectrl
from modules.scene.nodes.mesh_node import MeshNode
from modules.scene.nodes.capture_node import CaptureNode
import cv2
import os
import subprocess
class TreeCtrlScene(customtreectrl.CustomTreeCtrl):
  """Scene-outline tree shown in the main frame.

  Lists the captures, meshes, plenoptic fields, fiducials and render cameras of
  ``self.scene`` and offers per-item right-click menus whose handlers forward
  to ``main_frame`` and its VTK panel. The tree is rebuilt wholesale by
  ``refresh``.
  """

  def __init__(self, parent, id = wx.ID_ANY, pos = wx.DefaultPosition, size = wx.DefaultSize, style = 0, agwStyle=wx.TR_DEFAULT_STYLE | wx.TR_HIDE_ROOT, validator = wx.DefaultValidator, name = wx.TreeCtrlNameStr):
    super().__init__(parent, id, pos, size, style, agwStyle, validator, name)
    # The parent is the main frame; it owns the AUI manager (``m_mgr``) and the VTK tools.
    self.main_frame = parent
    self.aui_manager = self.main_frame.m_mgr
    self.tree_root = self.AddRoot("Scene")
    self.scene = None
    # True while captures are grouped under per-type items ("Order By > Shot Type").
    self.capture_type_mode = False
    # id / name / label -> tree item, used by select_item(); rebuilt by refresh().
    self.items = {}
    # (label, pos, tri_points, stats, id_pairs) tuples registered via add_fiducial_node().
    self.fiducial_nodes = []
    # Context of the most recent right-click: node object or label, and the clicked item.
    self.current_data = None
    self.current_tree_item_menu = None

    # Image index 0 = closed folder, 1 = opened folder (used by the "Captures" item).
    image_list = wx.ImageList(16, 16, False, 2)
    image_list.Add(self._load_icon("./assets/Icons/File and File Options 54x54 White/folder.png"))
    image_list.Add(self._load_icon("./assets/Icons/File and File Options 54x54 White/opened-folder.png"))
    self.AssignImageList(image_list)
    self.refresh()

    self.init_menu()

    self.Bind(wx.EVT_KEY_DOWN, self.on_key_down)
    self.Bind(wx.EVT_TREE_ITEM_RIGHT_CLICK, self.on_tree_item_right_click)

  @staticmethod
  def _load_icon(path, size=16):
    """Load the image at ``path`` and return it as a ``size`` x ``size`` wx.Bitmap."""
    image = wx.Image(path, wx.BITMAP_TYPE_ANY)
    image = image.Scale(size, size, quality=wx.IMAGE_QUALITY_HIGH)
    return wx.Bitmap(image)

  def _build_menu(self, entries):
    """Return a new wx.Menu with one item per ``(label, handler)`` pair, each bound on this tree."""
    menu = wx.Menu()
    for label, handler in entries:
      menu_item = menu.Append(wx.ID_ANY, label)
      self.Bind(wx.EVT_MENU, handler, id=menu_item.GetId())
    return menu

  def init_menu(self):
    """Create every right-click menu and bind its items to handlers on this tree."""
    # Mesh items.
    self.right_click_menu = self._build_menu([
        ("Align", self.align_geometry),
        ("Voxelize", self.voxelize_geometry),
        ("Hide", self.hide_geometry),
        ("Show", self.show_geometry),
    ])

    # Capture items.
    self.right_click_menu_captures = self._build_menu([
        ("Display Capture", self.display_capture),
        ("Visualize Fiducials", self.visualize_capture_fiducials),
        ("Project Surfels", self.project_surfels),
        ("Compute Pixel Variation", self.calc_pixel_variation),
        ("Delete Selected", self.delete_capture),
        ("Copy ID", self.copy_id),
        ("Show", self.show_capture),
        ("Hide", self.hide_capture),
    ])

    # Items directly under "Fiducials".
    self.right_click_menu_3D_fiducials = self._build_menu([
        ("Show", self.show_3d_fiducial),
        ("Hide", self.hide_3d_fiducial),
        ("Show Triangulated Points", self.show_tri_fid),
        ("Hide Triangulated Points", self.hide_tri_fid),
    ])

    # The "Captures" root: ordering check items, S0 export formats, frustum size.
    self.right_click_menu_capture_root = wx.Menu()
    order_submenu = wx.Menu()
    self.check_id = wx.MenuItem(order_submenu, wx.ID_ANY, "ID", kind=wx.ITEM_CHECK)
    self.check_timestamp = wx.MenuItem(order_submenu, wx.ID_ANY, "Timestamp", kind=wx.ITEM_CHECK)
    self.check_type = wx.MenuItem(order_submenu, wx.ID_ANY, "Shot Type", kind=wx.ITEM_CHECK)
    for check_item, handler in ((self.check_id, self.order_by_id),
                                (self.check_timestamp, self.order_by_timestamp),
                                (self.check_type, self.order_by_type)):
      order_submenu.Append(check_item)
      self.Bind(wx.EVT_MENU, handler, id=check_item.GetId())
    # The "ID" ordering item starts out checked (must happen after it is appended).
    self.check_id.Check()

    export_submenu = wx.Menu()
    self.tiff = wx.MenuItem(export_submenu, wx.ID_ANY, ".tiff")
    self.exr = wx.MenuItem(export_submenu, wx.ID_ANY, ".exr")
    for export_item, handler in ((self.tiff, self.export_s0_tiff),
                                 (self.exr, self.export_s0_exr)):
      export_submenu.Append(export_item)
      self.Bind(wx.EVT_MENU, handler, id=export_item.GetId())

    self.right_click_menu_capture_root.AppendSubMenu(order_submenu, "Order By")
    self.right_click_menu_capture_root.AppendSubMenu(export_submenu, "Export S0")
    adjust_item = self.right_click_menu_capture_root.Append(wx.ID_ANY, "Adjust Frustum Size")
    self.Bind(wx.EVT_MENU, self.adjust_frustum, id=adjust_item.GetId())

    # Items under a plenoptic field's "segments" node.
    self.right_click_menu_segment = self._build_menu([
        ("Optimize", self.optimize_segment),
        ("Show/Hide Surfels", self.show_hide_segment),
        ("Show/Hide Voxels", self.show_hide_voxel_segment),
    ])

    # Items under a fiducial's "Triangulated Points" node.
    self.tri_point_menu = self._build_menu([
        ("Show", self.show_tri_pair),
    ])

    # Items under "Cameras".
    self.right_click_cameras = self._build_menu([
        ("Render", self.render_camera),
    ])

  def render_camera(self, event):
    """Ask the main frame to render from the right-clicked camera ("camera_<x>" label)."""
    self.main_frame.on_tree_render_camera_pose(self.current_data)

  def _selected_item_data(self):
    """Return the item data attached to every selected tree item, in selection order."""
    return [self.GetItemData(selection) for selection in self.GetSelections()]

  def project_surfels(self, event):
    """Forward the data of all selected items to ``main_frame.project_surfels``."""
    self.main_frame.project_surfels(self._selected_item_data())

  def calc_pixel_variation(self, event):
    """Forward the data of all selected items to ``main_frame.calc_pixel_variation``."""
    self.main_frame.calc_pixel_variation(self._selected_item_data())

  def adjust_frustum(self, event):
    """Delegate to the VTK panel's ``adjust_frustum_size``."""
    self.main_frame.vtk_panel_tools.vtk_panel.adjust_frustum_size()

  def copy_id(self, event):
    """Copy the (whitespace-stripped) text of the first selected item to the clipboard.

    Uses wx's clipboard rather than ``echo ... | clip`` through a shell, which
    was Windows-only, interpolated item text into a shell command and appended
    a trailing newline. Does nothing when no item is selected or the clipboard
    cannot be opened.
    """
    selections = self.GetSelections()
    if not selections:
      return
    text = str(self.GetItemText(selections[0])).strip()
    clipboard = wx.TheClipboard
    if not clipboard.Open():
      return
    try:
      clipboard.SetData(wx.TextDataObject(text))
    finally:
      clipboard.Close()

  def show_tri_pair(self, event):
    """Ask the VTK panel to color the pair named by the right-clicked "<label>_tri_<n>" item."""
    self.main_frame.vtk_panel_tools.vtk_panel.color_3d_fiducial_pair(self.GetItemText(self.current_tree_item_menu))

  def optimize_segment(self, event):
    """Forward "Optimize" for the right-clicked segment (name in ``current_data``) to the main frame."""
    self.main_frame.on_menu_item_recon_selected(event, self.current_data)

  def show_hide_segment(self, event):
    """Toggle the "<segment>_glyph" and "<segment>_lines" actors together.

    Both actors receive the inverse of the glyph actor's current visibility.
    """
    props = self.main_frame.vtk_panel_tools.vtk_panel.props
    glyph_actor = props[self.current_data + "_glyph"]
    lines_actor = props[self.current_data + "_lines"]
    vis = 0 if glyph_actor.GetVisibility() else 1
    glyph_actor.SetVisibility(vis)
    lines_actor.SetVisibility(vis)
    self.main_frame.vtk_panel_tools.render()

  def show_hide_voxel_segment(self, event):
    """Toggle the "<segment>_cubes" actor and store the new state in ``glyph_visibility``."""
    vtk_panel = self.main_frame.vtk_panel_tools.vtk_panel
    cubes_actor = vtk_panel.props[self.current_data + "_cubes"]
    vis = 0 if cubes_actor.GetVisibility() else 1
    cubes_actor.SetVisibility(vis)
    # Slot 0 of the segment's visibility record tracks the cubes actor.
    vis_info = vtk_panel.glyph_visibility[self.current_data]
    vis_info[0] = vis
    vtk_panel.glyph_visibility[self.current_data] = vis_info
    self.main_frame.vtk_panel_tools.render()

  def _order_captures(self, sort_captures, selected_check, type_mode):
    """Sort captures, leave only ``selected_check`` checked, and rebuild the tree.

    The chosen item is explicitly re-checked: a wx check item toggles on click,
    so re-selecting the active ordering used to leave no ordering checked.
    """
    sort_captures()
    for check_item in (self.check_id, self.check_timestamp, self.check_type):
      check_item.Check(check_item is selected_check)
    self.capture_type_mode = type_mode
    self.refresh()
    self.Expand(self.capture_root)

  def order_by_id(self, event):
    """Order captures by id, listed flat under "Captures"."""
    self._order_captures(self.scene.sort_captures_by_id, self.check_id, False)

  def order_by_timestamp(self, event):
    """Order captures by timestamp, listed flat under "Captures"."""
    self._order_captures(self.scene.sort_captures_by_timestamp, self.check_timestamp, False)

  def order_by_type(self, event):
    """Order captures by shot type, grouped under one item per type."""
    self._order_captures(self.scene.sort_captures_by_type, self.check_type, True)

  def show_3d_fiducial(self, event):
    """Placeholder for the fiducial "Show" menu item; not implemented yet."""
    pass

  # Backward-compatible alias for the original (misspelled) handler name.
  show_3d_fiduciaL = show_3d_fiducial

  def hide_3d_fiducial(self, event):
    """Placeholder for the fiducial "Hide" menu item; not implemented yet."""
    pass

  def show_tri_fid(self, event):
    """Show the triangulated points of the right-clicked fiducial (looked up by its label)."""
    self.main_frame.vtk_panel_tools.vtk_panel.show_tri_points(self.GetItemText(self.current_tree_item_menu))

  def hide_tri_fid(self, event):
    """Hide the triangulated points of the right-clicked fiducial (looked up by its label)."""
    self.main_frame.vtk_panel_tools.vtk_panel.hide_tri_points(self.GetItemText(self.current_tree_item_menu))

  def display_capture(self, event):
    """Delegate to ``main_frame.view_selected_capture``."""
    self.main_frame.view_selected_capture()

  def visualize_capture_fiducials(self, event):
    """Pass the selected tree items to ``main_frame.draw_test``."""
    self.main_frame.draw_test(self.GetSelections())

  def delete_capture(self, event):
    """Pass the selected tree items to ``main_frame.delete_captures``."""
    self.main_frame.delete_captures(self.GetSelections())

  def _set_prop_visibility(self, prop_name, visible):
    """Set the visibility of the VTK prop registered as ``prop_name`` and re-render."""
    actor = self.main_frame.vtk_panel_tools.vtk_panel.props[prop_name]
    actor.SetVisibility(visible)
    self.main_frame.vtk_panel_tools.render()

  def show_capture(self, event):
    """Show the "capture_<id>" prop of the right-clicked capture."""
    self._set_prop_visibility(f"capture_{self.current_data.id}", True)

  def hide_capture(self, event):
    """Hide the "capture_<id>" prop of the right-clicked capture."""
    self._set_prop_visibility(f"capture_{self.current_data.id}", False)

  def voxelize_geometry(self, event):
    """Pass the data of all selected items (mesh nodes) to ``main_frame.voxelize``."""
    self.main_frame.voxelize(self._selected_item_data())

  def align_geometry(self, event):
    """Pass the right-clicked mesh node to ``main_frame.align_obj``."""
    self.main_frame.align_obj(self.current_data)

  def hide_geometry(self, event):
    """Hide the VTK prop named after the right-clicked mesh."""
    self._set_prop_visibility(self.current_data.get_name(), False)

  def show_geometry(self, event):
    """Show the VTK prop named after the right-clicked mesh."""
    self._set_prop_visibility(self.current_data.get_name(), True)

  def export_s0_tiff(self, event):
    """Export S0 of every capture as .tiff."""
    self.export_s0(".tiff")

  def export_s0_exr(self, event):
    """Export S0 of every capture as .exr."""
    self.export_s0(".exr")

  def export_s0(self, ext):
    """Write Stokes component 0 of every capture to ``<dir>/<type>/<id><ext>``.

    The user picks ``<dir>`` with a directory dialog; cancelling exports nothing.
    Captures come from ``self.scene.capture_nodes`` rather than the tree, so the
    export also works when captures are grouped by shot type (the group items
    carry no node data). Paths cv2 fails to write are reported in a message box.

    Args:
      ext: File extension including the dot, e.g. ".tiff" or ".exr".
    """
    dialog = wx.DirDialog(self, "Choose S0 Directory Folder", style=wx.DD_DIR_MUST_EXIST)
    try:
      if dialog.ShowModal() != wx.ID_OK:
        return
      path = dialog.GetPath()
    finally:
      dialog.Destroy()

    failed = []
    for capture_node in self.scene.capture_nodes:
      sub_folder = os.path.join(path, str(capture_node.type))
      os.makedirs(sub_folder, exist_ok=True)
      save_path = os.path.join(sub_folder, str(capture_node.id) + ext)
      to_write = capture_node.capture.GetImage().stokes()[0]
      # cv2.imwrite reports failure via its return value rather than raising.
      if not cv2.imwrite(save_path, to_write):
        failed.append(save_path)

    if failed:
      wx.MessageBox("Failed to write:\n" + "\n".join(failed), "Export S0",
                    wx.OK | wx.ICON_WARNING, self)

  def set_scene(self, scene):
    """Attach ``scene`` and rebuild the tree from it."""
    self.scene = scene
    self.refresh()

  def select_item(self, id):
    """Select, scroll to and focus the tree item registered under ``id``.

    A leading "capture_" (the VTK prop naming used for captures) is stripped
    first. Raises KeyError if nothing is registered under the resulting key.
    """
    prefix = "capture_"
    if id.startswith(prefix):
      id = id[len(prefix):]
    item = self.items[id]
    self.SelectItem(item)
    self.ScrollTo(item)
    self.SetFocus()

  def on_key_down(self, event):
    """Handle the 'f' (fly toward) and 'v' (view capture) shortcuts.

    Every other key — and every key while nothing is selected — is skipped so
    CustomTreeCtrl's own keyboard navigation keeps working.
    """
    selections = self.GetSelections()
    if not selections:
      event.Skip()
      return

    key = event.GetKeyCode()
    tools = self.main_frame.vtk_panel_tools
    if key == ord('F'):
      tools.fly_toward_selected()
      return

    data = self.GetItemData(selections[0])
    if key == ord('V') and isinstance(data, CaptureNode):
      tools.fly_to_selected()
      tools.display_capture_in_frustum(data)
      tools.vtk_panel.show_frustums()
      tools.hide_frustums(f"capture_{data.get_id()}")
      # Consumed like 'f', so the tree's type-ahead search does not also react to it.
      return

    event.Skip()

  def on_tree_item_right_click(self, event):
    """Set the menu context for the clicked item and pop up the matching menu.

    Structural items are recognised by their own text ("Captures") or their
    parent's text; everything else is dispatched on the type of its item data.
    Each structural branch returns so the context it set is not overwritten.
    """
    item = event.GetItem()
    item_text = self.GetItemText(item)
    point = event.GetPoint()
    self.current_tree_item_menu = item

    if item_text == "Captures":
      self.current_data = "Captures"
      self.PopupMenu(self.right_click_menu_capture_root, point)
      return

    parent = self.GetItemParent(item)
    parent_text = self.GetItemText(parent) if parent is not None else None

    if parent_text == "Fiducials":
      self.current_data = "Tri Pair"
      self.PopupMenu(self.right_click_menu_3D_fiducials, point)
      return
    if parent_text == "segments":
      self.current_data = item_text
      self.PopupMenu(self.right_click_menu_segment, point)
      return
    if parent_text == "Triangulated Points":
      # The tri-point handler reads current_tree_item_menu, not current_data.
      self.current_data = self.GetItemData(item)
      self.PopupMenu(self.tri_point_menu, point)
      return
    if parent_text == "Cameras":
      self.current_data = item_text
      self.PopupMenu(self.right_click_cameras, point)
      return

    self.current_data = self.GetItemData(item)
    if isinstance(self.current_data, MeshNode):
      self.PopupMenu(self.right_click_menu, point)
    elif isinstance(self.current_data, CaptureNode):
      self.PopupMenu(self.right_click_menu_captures, point)

  def reset_scene(self):
    """Delete every tree item and forget registered fiducials and item handles."""
    self.DeleteAllItems()
    self.fiducial_nodes = []
    self.items = {}
    self.cap_dict = {}

  def add_fiducial_node(self, label, pos, tri_points, stats, id_pairs):
    """Register a fiducial and rebuild the tree to show it."""
    self.fiducial_nodes.append((label, pos, tri_points, stats, id_pairs))
    self.refresh()

  def expand_capture_root(self):
    """Expand the "Captures" item."""
    self.Expand(self.capture_root)

  def expand_capture_item(self, item_name):
    """Expand the capture item registered under ``item_name`` and the "Captures" item."""
    self.Expand(self.cap_dict[item_name])
    self.Expand(self.capture_root)

  def _append_capture_items(self):
    """Append the "Captures" subtree, grouped by shot type in ``capture_type_mode``."""
    self.capture_root = self.AppendItem(self.tree_root, "Captures")
    self.capture_root.SetImage(0, wx.TreeItemIcon_Normal)
    self.capture_root.SetImage(1, wx.TreeItemIcon_Expanded)

    type_groups = {}
    if self.capture_type_mode:
      # NOTE(review): groups are keyed by 1-based position in get_types(), so this
      # assumes capture_node.get_type() returns that same index — confirm.
      for index, cap_type in enumerate(self.scene.get_types(), start=1):
        type_groups[index] = self.AppendItem(self.capture_root, str(cap_type))

    for capture_node in self.scene.capture_nodes:
      if self.capture_type_mode:
        parent = type_groups[capture_node.get_type()]
      else:
        parent = self.capture_root

      capture_id = capture_node.get_id()
      item = self.AppendItem(parent, str(capture_id))
      self.cap_dict[capture_id] = item
      self.SetItemData(item, capture_node)
      self.items[capture_id] = item
      # One child per label, with the label's value as a grandchild.
      cap_labels = capture_node.get_labels()
      for label in cap_labels:
        label_item = self.AppendItem(item, label)
        self.AppendItem(label_item, str(cap_labels[label]))

  def _append_mesh_items(self):
    """Append the "Meshes" subtree; each item carries its mesh node."""
    self.geometry_root = self.AppendItem(self.tree_root, "Meshes")
    for mesh_node in self.scene.mesh_nodes:
      mesh_name = mesh_node.get_name()
      item = self.AppendItem(self.geometry_root, str(mesh_name))
      self.SetItemData(item, mesh_node)
      self.items[mesh_name] = item

  def _append_plenoptic_field_items(self):
    """Append the "Plenoptic Fields" subtree with each field's segments and light fields."""
    self.plenoptic_field_root = self.AppendItem(self.tree_root, "Plenoptic Fields")
    for plenoptic_field_node in self.scene.plenoptic_field_nodes:
      pf_id = plenoptic_field_node['id']
      pf_item = self.AppendItem(self.plenoptic_field_root, str(pf_id))
      self.SetItemData(pf_item, plenoptic_field_node)
      # Keyed by the field's own id (this used to be the builtin ``id`` function).
      self.items[pf_id] = pf_item

      # Only segments assigned to this field are listed; per-surfel children are
      # intentionally not added to the tree.
      seg_item_root = self.AppendItem(pf_item, "segments")
      for name in plenoptic_field_node['pfNode'].segments:
        if self.scene.pf_assignments[name] == pf_id:
          self.items[name] = self.AppendItem(seg_item_root, name)

      lf_item_root = self.AppendItem(pf_item, "light_fields")
      for plf_at, plf in plenoptic_field_node['plfs'].items():
        self.AppendItem(lf_item_root, f"{plf_at} : {str(plf.size())} radiels")

  def _append_fiducial_items(self):
    """Append the "Fiducials" subtree from ``self.fiducial_nodes``."""
    self.fiducial_root = self.AppendItem(self.tree_root, "Fiducials")
    for label, pos, _tri_points, stats, id_pairs in self.fiducial_nodes:
      item = self.AppendItem(self.fiducial_root, label)
      item.SetData((pos, stats))
      self.items[label] = item
      self.AppendItem(item, str(pos))

      tri_points_root = self.AppendItem(item, "Triangulated Points")
      # NOTE(review): these children are built from id_pairs (tuple slot 4), not
      # tri_points (slot 2), despite the "Triangulated Points" heading — confirm.
      for index, point in enumerate(id_pairs):
        parent_item = self.AppendItem(tri_points_root, f"{label}_tri_{index}")
        for component in (point[0], point[1], point[2]):
          self.AppendItem(parent_item, str(component))

  def _append_camera_items(self):
    """Append the "Cameras" subtree; each item's data is its "camera_<x>" name."""
    self.cameras_root = self.AppendItem(self.tree_root, "Cameras")
    for camera in self.scene.render_cameras:
      camera_name = f"camera_{camera}"
      item = self.AppendItem(self.cameras_root, camera_name)
      item.SetData(camera_name)
      self.items[camera_name] = item

  def refresh(self):
    """Rebuild the whole tree from ``self.scene`` and ``self.fiducial_nodes``.

    All items are deleted first, so ``items`` and ``cap_dict`` are rebuilt from
    scratch. While ``scene`` is None only the hidden root exists.
    """
    self.DeleteAllItems()
    self.cap_dict = {}
    # Every previous item was just deleted; drop their now-stale handles.
    self.items = {}
    self.tree_root = self.AddRoot("Scene")
    if self.scene is not None:
      aui_pane_info = self.aui_manager.GetPane(self.GetParent())
      aui_pane_info.Caption(self.scene.name)
      if len(self.scene.capture_nodes) > 0:
        self._append_capture_items()
      if len(self.scene.mesh_nodes) > 0:
        self._append_mesh_items()
      if len(self.scene.plenoptic_field_nodes) > 0:
        self._append_plenoptic_field_items()
      if len(self.fiducial_nodes) > 0:
        self._append_fiducial_items()
      if len(self.scene.render_cameras) > 0:
        self._append_camera_items()
      self.aui_manager.Update()