"""Exporters for trained Gaussian-splatting models (currently: refined camera poses in COLMAP format)."""
import sys, os
sys.path.append(os.path.abspath('../../../'))
sys.path.append(os.path.abspath('../'))
sys.path.append(os.path.abspath('./'))

import random
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import numpy as np
import open3d as o3d
import torch
import torch.nn.functional as F
import tyro
from tqdm import tqdm
from typing_extensions import Annotated

from nerfstudio.benchmark_utils.utils import (
    get_colored_points_from_depth,
    get_means3d_backproj,
    project_pix,
    quat_to_rotmat
)

from nerfstudio.data.utils.colmap_parsing_utils import (
    read_points3D_binary,
)

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.models.raytracingfacto import RaytracingfactoModel
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.data.utils.colmap_parsing_utils import (
    Camera, Image, Point3D,
    rotmat2qvec, write_cameras_binary, write_images_binary, write_points3D_binary
)
from nerfstudio.utils.rich_utils import CONSOLE

@dataclass
class ColmapFormatCameraExtractor:
    """
    Extract refined poses to COLMAP format.

    Renders every training camera through the trained model to obtain its refined
    camera (``outputs["refined_camera"]``) and writes a COLMAP sparse model
    (``cameras.bin``, ``images.bin``, ``points3D.bin``) to ``output_dir``, with one
    PINHOLE camera entry per image. The Gaussian means are written as the 3D points.
    """

    load_config: Path
    """Path to the trained config YAML file."""

    output_dir: Path = Path("./sparse_refined/0")
    """Path to the output sparse model directory."""

    @staticmethod
    def _colmap_pose_from_c2w(camera_to_world: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Convert a 3x4 camera-to-world matrix into a COLMAP world-to-camera pose.

        Args:
            camera_to_world: ``(3, 4)`` camera-to-world matrix of a refined camera.

        Returns:
            ``(R_w2c, t_w2c)`` as numpy arrays of shape ``(3, 3)`` and ``(3,)``.
        """
        c2w = torch.eye(4, dtype=torch.float, device=camera_to_world.device)
        c2w[:3, :4] = camera_to_world
        # Negate the camera y and z axes (presumably converting nerfstudio's OpenGL-style
        # camera frame into COLMAP's OpenCV-style frame).
        axis_flip = torch.diag(torch.tensor([1, -1, -1, 1], device=c2w.device, dtype=torch.float))
        c2w = c2w @ axis_flip
        # Invert the rigid transform: R_w2c = R_c2w^T, t_w2c = -R_c2w^T @ t_c2w.
        R_w2c = c2w[:3, :3].T
        t_w2c = -R_w2c @ c2w[:3, 3]
        return R_w2c.cpu().numpy(), t_w2c.cpu().numpy()

    def main(self):
        """Load the trained pipeline and write its refined training cameras as a COLMAP model."""
        # exist_ok=True replaces the separate exists() check (and its check-then-create race).
        self.output_dir.mkdir(parents=True, exist_ok=True)

        _, pipeline, _, _ = eval_setup(self.load_config)
        model: RaytracingfactoModel = pipeline.model
        train_dataset = pipeline.datamanager.train_dataset  # type: ignore
        # TODO: do eval dataset as well if present

        colmap_images = {}
        colmap_cameras = {}
        points = {}

        with torch.no_grad():
            cameras: Cameras = train_dataset.cameras  # type: ignore
            image_filenames = train_dataset.image_filenames

            # Loop over indices rather than over the dataset itself: iterating the dataset
            # materializes every item (presumably loading its image), which was never used.
            for image_idx in tqdm(range(len(train_dataset)), "Calculating refined poses"):
                camera_orig = cameras[image_idx : image_idx + 1]
                outputs = model.get_outputs_for_camera(camera=camera_orig)
                camera_refined = outputs["refined_camera"]

                R_w2c, t_w2c = self._colmap_pose_from_c2w(camera_refined.camera_to_worlds.squeeze(0))
                image_id = image_idx + 1  # ids written here start at 1

                colmap_images[image_id] = Image(
                    id=image_id,
                    qvec=rotmat2qvec(R_w2c),
                    tvec=t_w2c,
                    camera_id=image_id,  # one camera entry per image
                    name=os.path.basename(image_filenames[image_idx]),
                    xys=[],
                    point3D_ids=[],
                )
                colmap_cameras[image_id] = Camera(
                    id=image_id,
                    model="PINHOLE",
                    width=camera_refined.width.item(),
                    height=camera_refined.height.item(),
                    # PINHOLE intrinsics in the order (fx, fy, cx, cy).
                    params=np.array(
                        [
                            camera_refined.fx.item(),
                            camera_refined.fy.item(),
                            camera_refined.cx.item(),
                            camera_refined.cy.item(),
                        ]
                    ),
                )

            # Write Gaussian means as points. detach() matters when the means already live on
            # the CPU: .cpu() then returns the tensor itself, and .numpy() rejects tensors that
            # require grad.
            means = model.gauss_params["means"].detach().cpu().numpy()
            for i, xyz in tqdm(enumerate(means), "Extracting Gaussian means", total=len(means)):
                point_id = i + 1
                points[point_id] = Point3D(
                    id=point_id,
                    xyz=xyz,
                    rgb=np.array([0, 0, 255]),  # TODO get color, requires evaluating SH
                    error=0.0,
                    # NOTE(review): placeholder track. It claims observations in images 1-3 at
                    # 2D indices 1-3, although images are written without keypoints (xys=[])
                    # and images 2/3 may not exist. Kept as-is for downstream compatibility.
                    image_ids=np.array([1, 2, 3]),
                    point2D_idxs=np.array([1, 2, 3]),
                )

        CONSOLE.print(
            f"Writing refined sparse model to: {str(self.output_dir)}"
        )
        write_images_binary(colmap_images, str(self.output_dir / "images.bin"))
        write_cameras_binary(colmap_cameras, str(self.output_dir / "cameras.bin"))
        write_points3D_binary(points, str(self.output_dir / "points3D.bin"))


# CLI command registry: every Union member is a command class annotated with its subcommand name.
# NOTE(review): typing.Union with a single member collapses to that member, so tyro presumably
# exposes the extractor's flags directly instead of behind a "colmap" subcommand — confirm intended.
_ColmapSubcommand = Annotated[ColmapFormatCameraExtractor, tyro.conf.subcommand(name="colmap")]
Commands = tyro.conf.FlagConversionOff[Union[_ColmapSubcommand]]


def entrypoint():
    """Entrypoint for use with pyproject scripts: parse the CLI into a command and run it."""
    tyro.extras.set_accent_color("bright_yellow")
    command = tyro.cli(Commands)
    command.main()


if __name__ == "__main__":
    # Allow running this module directly as a script.
    entrypoint()
