
from modules.scene.scene import Scene
#from qr.plenoptic_py import PlenopticField
from modules.scene.fiducial_point import FiducialPoint
#import open3d as o3d
import numpy as np
import cv2
import os
from PIL import Image
import time
from pathlib import Path

def link(self, event):
  """Debug handler that announces a "link" event on stdout.

  Neither ``self`` nor ``event`` is used; presumably they exist so the function
  can be bound as an event callback — TODO confirm against callers.
  """
  message = "link"
  print(message)

class SceneController:
  """Connects the main window's UI callbacks to a ``Scene`` model.

  On construction it registers its handlers on ``main_frame``. It keeps two pieces
  of annotation state:

  * ``saved_fiducials``: capture id -> that capture's fiducials (presumably a dict
    keyed by fiducial label, as ``remove_fiducial`` deletes by key — verify).
  * ``saved_fiducial_labels``: ordered list of known fiducial labels.
  """

  # Default folder searched for per-capture silhouette masks ("<capture id>.png").
  # NOTE(review): developer-machine path; pass ``mask_folder`` explicitly instead.
  DEFAULT_MASK_FOLDER = "C:/Users/AlexRablau/Desktop/brass_pear_genie/good_masks"

  def __init__(self, main_frame):
    self.scene = Scene()

    self.main_frame = main_frame
    # Route every main-frame UI callback to the matching handler on this controller.
    self.main_frame.on_menu_item_open_scene_callback = self.load_scene
    self.main_frame.on_menu_item_save_scene_callback = self.save_scene
    self.main_frame.on_menu_item_import_obj_callback = self.import_obj
    self.main_frame.on_menu_item_add_pf_callback = self.create_pf
    self.main_frame.on_menu_item_voxalize_geom_callback = self.voxelize_geom
    self.main_frame.on_menu_item_sample_plf_callback = self.sample_plf
    self.main_frame.on_recon_silhouette_carving_callback = self.recon_silhouette_carving
    self.main_frame.receive_capture_callback = self.get_fiducials
    self.main_frame.save_fiducial_callback = self.add_fiducial
    self.main_frame.remove_fiducial_callback = self.remove_fiducial
    self.main_frame.receive_label_callback = self.get_fiducial_labels
    self.main_frame.save_label_callback = self.add_fiducial_label
    self.main_frame.remove_label_callback = self.remove_fiducial_label
    self.main_frame.get_fiducials_by_label_callback = self.get_fiducials_by_label
    self.main_frame.create_fiducial_point_callback = self.create_fiducial_point
    self.main_frame.apply_colmap_poses_callback = self.apply_colmap_poses

    self.saved_fiducials = {}
    self.saved_fiducial_labels = []

  def apply_colmap_poses(self, colmap_dir):
    """Forward ``colmap_dir`` to ``Scene.apply_colmap_poses``."""
    self.scene.apply_colmap_poses(colmap_dir)

  def create_fiducial_point(self, label, points, captures):
    """Return a new ``FiducialPoint`` bound to the current scene."""
    return FiducialPoint(label, points, captures, self.scene)

  def import_obj(self, obj_file, fids_file, auto_align):
    """Create a geometry node from ``obj_file`` and add it to the viewport."""
    mesh_node = self.scene.create_geometry_node(obj_file, fids_file, auto_align)
    self.main_frame.panel_viewport.add_imported_mesh(mesh_node)

  def get_fiducial_labels(self):
    """Return the (live, mutable) list of known fiducial labels."""
    return self.saved_fiducial_labels

  def add_fiducial_label(self, label):
    """Append ``label`` to the known labels (duplicates are not filtered)."""
    self.saved_fiducial_labels.append(label)

  def remove_fiducial_label(self, label):
    """Remove the first occurrence of ``label``; raises ValueError if absent."""
    self.saved_fiducial_labels.remove(label)

  def inser_saved_fiducials(self, capture_id, fiducials):
    """Store ``fiducials`` for ``capture_id``, replacing any previous entry.

    (Name keeps its historical spelling for existing callers.)
    """
    self.saved_fiducials[capture_id] = fiducials

  @staticmethod
  def _scene_file_type(scene_file_path):
    """Map a scene file suffix to ``(json_zipped, file_type)`` for ``Scene``.

    ``.zip`` -> ``(True, 1)``, ``.json`` -> ``(False, 1)``, anything else ->
    ``(False, 0)``. The suffix comparison is case-insensitive.
    """
    suffix = Path(scene_file_path).suffix.lower()
    if suffix == '.zip':
      return True, 1
    if suffix == '.json':
      return False, 1
    return False, 0

  @staticmethod
  def _collect_labels(fiducials):
    """Return every label used by any capture in ``fiducials``, first-seen order."""
    labels = []
    for capture in fiducials:
      for label in fiducials[capture]:
        if label not in labels:
          labels.append(label)
    return labels

  def load_scene(self, scene_file_path):
    """Load a scene from ``scene_file_path`` and reset fiducial state from it.

    If the scene fails to load, ``self.scene`` is set to ``None`` and the
    fiducial state is left untouched.
    """
    zipped, ft = self._scene_file_type(scene_file_path)

    self.scene = Scene(scene_file_path, json_zipped = zipped, file_type = ft)

    if not self.scene.is_loaded():
      self.scene = None
      return

    self.saved_fiducials = self.scene.get_loaded_fiducials()

    if len(self.scene.labels) == 0:
      # Build a fresh list. Appending to the existing one would carry over labels
      # from the previously loaded scene, and would mutate that scene's ``labels``
      # when it was aliased by the else-branch below.
      self.saved_fiducial_labels = self._collect_labels(self.saved_fiducials)
    else:
      # Deliberately shares the scene's list, so label edits are visible on the scene.
      self.saved_fiducial_labels = self.scene.labels

    # self.main_frame.set_scene(self.scene)
    # self.main_frame.menu_item_add_pf.Enable(True)
    # if self.scene.plenoptic_field_nodes:
    #   self.main_frame.menu_item_select_voxel.Enable(True)
    #   self.main_frame.menu_item_select_roi.Enable(True)
    #   self.main_frame.menu_item_sample_plf.Enable(True)

  def save_scene(self, scene_file_path):
    """Save the current scene together with the controller's fiducials.

    NOTE(review): ``self.scene`` is ``None`` after a failed ``load_scene``, in which
    case this raises AttributeError.
    """
    self.scene.save_scene(scene_file_path, self.saved_fiducials)

  def create_pf(self, min_corner, size):
    """Add a plenoptic field at ``min_corner`` with extent ``size`` and refresh the UI.

    Does nothing when no scene is loaded.
    """
    if self.scene is None:
      return
    # NOTE(review): ``PlenopticField`` is not imported in this file (its import is
    # commented out at the top), so this raises NameError unless provided elsewhere.
    pf = PlenopticField(min_corner, size)
    self.scene.add_plenoptic_field_node(pf)
    self.main_frame.vtk_panel.add_pf_node(pf)
    self.main_frame.scene_tree_ctrl.refresh()
    self.main_frame.menu_item_sample_plf.Enable(True)
    self.main_frame.menu_item_select_voxel.Enable(True)
    self.main_frame.menu_item_select_roi.Enable(True)

  def sample_plf(self, query_position, level):
    """Sample the scene's plenoptic field at ``query_position``/``level`` and display it."""
    plfCube = self.scene.sample_plf(query_position, level)
    self.main_frame.vtk_panel.add_splf(plfCube, query_position, level)

  def voxelize_geom(self, geomNode):
    """Voxelize ``geomNode`` into the scene and display the resulting mediel nodes.

    Uses the first plenoptic field node's world size for display. Prints the
    elapsed time of the voxelization and of the viewport update.
    """
    pf = self.scene.plenoptic_field_nodes[0]['pfNode'].plenopticField
    pf_size = pf.get_world_size()
    print(f"pf size: {pf_size}")

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock.
    start_time = time.perf_counter()
    self.scene.voxelize_geometry(geomNode)
    print(f"Elapsed time1: {time.perf_counter() - start_time:.6f} seconds")

    start_time = time.perf_counter()
    self.main_frame.vtk_panel.add_mediel_nodes(self.scene.mediel_nodes, pf_size, True, False)
    print(f"Elapsed time2: {time.perf_counter() - start_time:.6f} seconds")

  @staticmethod
  def _capture_extrinsic(world_matrix):
    """Build a 4x4 world-to-camera extrinsic matrix from a capture's world matrix.

    The rotation columns are re-ordered/flipped to (z, -x, -y) before inverting.
    This presumably converts the capture's axis convention to the one expected by
    the (disabled) carving step — TODO confirm.
    """
    world_matrix = np.asarray(world_matrix)
    rotation = world_matrix[:3, :3]
    rotation = np.concatenate((rotation[:, 2:3], -rotation[:, 0:1], -rotation[:, 1:2]), axis=1)
    translation = world_matrix[:3, 3]

    extrinsic_matrix = np.eye(4)
    extrinsic_matrix[:3, :3] = rotation.T
    extrinsic_matrix[:3, 3] = -rotation.T @ translation
    return extrinsic_matrix

  @staticmethod
  def _mask_to_gray(mask_image):
    """Return ``mask_image`` as a single-channel image (already-gray input is kept)."""
    if mask_image.ndim == 2:
      return mask_image
    return cv2.cvtColor(mask_image, cv2.COLOR_BGR2GRAY)

  def recon_silhouette_carving(self, mask_folder=None, show_masks=True):
    """Prepare per-capture masks and camera parameters for silhouette carving.

    For each scene capture, load ``<mask_folder>/<capture id>.png``, optionally
    show it (blocking until a key press), then convert it to one channel and
    resize it to the capture's image size. The carving step itself is still
    disabled (see the open3d TODO below), so the prepared data is not consumed yet.

    Args:
      mask_folder: folder holding the masks; defaults to ``DEFAULT_MASK_FOLDER``.
      show_masks: show each mask with ``cv2.imshow`` and wait for a key press.
    """
    print("Scene ctrl recon sil carving")
    if mask_folder is None:
      mask_folder = self.DEFAULT_MASK_FOLDER
    scene_capture_nodes = self.scene.entities["captures"]
    # TODO: open3d carving is disabled. Intended flow:
    #   voxel_grid = o3d.geometry.VoxelGrid.create_dense(origin=[-0.19, -0.04, 0.65], color=[1,0,1], voxel_size=0.05, width=0.25, height=0.25, depth=0.25)
    #   per capture: pinhole_parameters.extrinsic = extrinsic_matrix
    #                pinhole_parameters.intrinsic.set_intrinsics(w, h, K[0,0], K[1,1], K[0,2], K[1,2])
    #                voxel_grid.carve_silhouette(o3d.geometry.Image(mask_image), pinhole_parameters, keep_voxels_outside_image=True)
    #   then: self.main_frame.vtk_panel.add_voxel_grid(voxel_grid)

    for scene_capture_node in scene_capture_nodes:
      print(f"Capture {scene_capture_node.id}")
      mask_path = os.path.join(mask_folder, f"{scene_capture_node.id}.png")
      if not os.path.exists(mask_path):
        print(f"Mask {mask_path} not found")
        continue

      # Capture parameters; K and extrinsic_matrix feed the disabled carving step.
      camera_model = scene_capture_node.capture.GetCameraModel()
      camera_image = scene_capture_node.capture.GetImage()
      image_size = camera_image.size()
      K = camera_model.GetCameraIntrinsicMatrix()
      extrinsic_matrix = self._capture_extrinsic(scene_capture_node.capture.GetPose().GetWorldMatrix())

      mask_image = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
      # imread returns None (instead of raising) for unreadable/corrupt files.
      if mask_image is None:
        print(f"Mask {mask_path} could not be read")
        continue
      if show_masks:
        cv2.imshow("mask", mask_image)
        cv2.waitKey(0)
      mask_image = self._mask_to_gray(mask_image)
      # cv2.resize takes (width, height); assumes image_size is (width, height) — TODO confirm.
      mask_image = cv2.resize(mask_image, (image_size[0], image_size[1]))

  def get_fiducials_by_label(self):
    """Group capture fiducials by scene label.

    Returns:
      dict mapping each label from ``scene.get_labels()`` to a list of
      ``(capture.get_labels()[label], capture)`` pairs, one per capture that
      has the label, in ``scene.capture_nodes`` order (empty list if none).
    """
    labels = self.scene.get_labels()
    # Query each capture's labels once instead of once per (label, capture) pair.
    capture_labels = [(capture, capture.get_labels()) for capture in self.scene.capture_nodes]
    return {
      label: [(cap_labels[label], capture) for capture, cap_labels in capture_labels if label in cap_labels]
      for label in labels
    }

  def add_fiducial(self, id, fiducial):
    """Store ``fiducial`` for capture ``id`` (replacing any previous entry) and print all."""
    self.saved_fiducials[id] = fiducial
    print(self.saved_fiducials)

  def remove_fiducial(self, id, fiducial):
    """Delete key ``fiducial`` from capture ``id``'s fiducials; raises KeyError if absent."""
    del self.saved_fiducials[id][fiducial]

  def get_fiducials(self, capture_id):
    """Return the fiducials stored for ``capture_id``, or ``None`` if there are none."""
    return self.saved_fiducials.get(capture_id)
