import argparse
import torch
import json
import numpy as np
import shutil
import os

import cv2
import matplotlib
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from pathlib import Path
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.cameras.cameras import Cameras
from typing import Optional
from tqdm import tqdm

from nerfstudio.models.splatfacto import SplatfactoModel, resize_image
from nerfstudio.utils.dn_utils import normal_from_depth_image

# Seed NumPy's global RNG at import time so the random view subset drawn with
# np.random in load_configs is the same on every run.
np.random.seed(seed=11111)


def load_configs(scene_name, benchmark_path, filenames, result_path, sample_ratio=0.05):
    """Resolve which training images should be rendered for ``scene_name``.

    Lookup order:

    1. ``<benchmark_path>/render_configs/<scene_name>.json`` if ``benchmark_path``
       is given and that file exists.
    2. ``<result_path>/<scene_name>.json`` if it already exists (e.g. written by
       an earlier run).
    3. Otherwise a random subset of ``filenames`` is drawn with ``np.random``,
       written to ``<result_path>/<scene_name>.json`` and then used.

    Args:
        scene_name: Stem of the JSON view-list file.
        benchmark_path: Optional directory containing a ``render_configs`` folder.
        filenames: Indexable sequence of ``Path`` objects for the candidate images.
        result_path: Directory where a generated view list is cached.
        sample_ratio: Fraction of ``filenames`` to sample when no view list
            exists. At least one image is sampled whenever ``filenames`` is non-empty.

    Returns:
        dict mapping each image file name (``Path(p).name``) to the path string
        stored in the JSON file.

    Raises:
        KeyError: if the selected JSON file has no ``"images"`` key.
    """
    views_path = result_path.joinpath(f"{scene_name}.json")

    views = None
    if benchmark_path is not None:
        benchmark_views = benchmark_path.joinpath(f"render_configs/{scene_name}.json")
        if benchmark_views.exists():
            views = benchmark_views

    if views is None:
        if not views_path.exists():
            num_files = len(filenames)
            # Never sample zero views for small scenes; for larger ones the count
            # (and therefore the seeded selection) matches int(num_files * sample_ratio).
            num_views = min(num_files, max(1, int(num_files * sample_ratio)))
            idx = np.random.choice(num_files, num_views, replace=False) if num_views else []
            # Keyed by file name so duplicate names collapse to a single entry.
            selected = {filenames[i].name: filenames[i] for i in idx}
            views_path.parent.mkdir(parents=True, exist_ok=True)
            with open(views_path, "w", encoding="utf-8") as f:
                json.dump({"images": [str(p) for p in selected.values()]}, f, indent=4)
        views = views_path

    with open(views, "r", encoding="utf-8") as fs:
        data = json.load(fs)

    if "images" not in data:
        raise KeyError(f"{views} must contain a key called 'images' holding a list of image paths.")
    return {Path(f).name: f for f in data["images"]}

def update_config_callback(config):
    """Set the datamanager's camera resolution scale factor to 0.25.

    Mutates ``config`` in place and hands the same object back, as expected of
    an ``eval_setup`` config callback.
    """
    datamanager_config = config.pipeline.datamanager
    datamanager_config.camera_res_scale_factor = 0.25
    return config


def _make_output_dir(result_path, name):
    """Create ``result_path / name`` (with any missing parents) and return it."""
    path = result_path.joinpath(name)
    path.mkdir(parents=True, exist_ok=True)
    return path


def _find_mono_depth_file(image_file):
    """Return the monocular-depth ``.npy`` path that accompanies ``image_file``.

    Looks in ``<image_file.parent.parent>/mono_depth`` for ``<name>.npy`` first
    (e.g. ``frame.png.npy``) and falls back to ``<stem>.npy`` (e.g. ``frame.npy``).
    The fallback is returned even if it does not exist, so ``np.load`` reports it.
    """
    mono_depth_dir = image_file.parent.parent.joinpath("mono_depth")
    by_name = mono_depth_dir.joinpath(f"{image_file.name}.npy")
    if by_name.exists():
        return by_name
    return mono_depth_dir.joinpath(f"{image_file.stem}.npy")


def _colorize_depth(depth, min_depth, max_depth, cmap):
    """Colorize a 2-D depth array with ``cmap`` over ``[min_depth, max_depth]``.

    Returns a ``uint8`` image in BGR channel order, ready for ``cv2.imwrite``.
    """
    depth_range = max_depth - min_depth
    if depth_range > 0:
        scaled = (depth - min_depth) / depth_range * 255.0
    else:
        # Constant (or NaN) range: dividing would produce inf/NaN; map everything to index 0.
        scaled = np.zeros_like(depth)
    # Integer input makes the colormap index its lookup table directly.
    rgba = cmap(scaled.astype(np.uint8))
    # Drop alpha, scale to [0, 255] and flip RGB -> BGR for OpenCV.
    return (rgba[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)


def render(scene_name, dataset, model_name, timestamp, result_path=None, config_path=None, benchmark_path=None):
    """Render ground-truth and model outputs for a subset of training views.

    For every training image selected by ``load_configs``, writes the following
    files under ``result_path``:

    - ``gt_images/``: the ground-truth RGB image.
    - ``gt_depth/``: a copy of the monocular-depth ``.npy`` and a colorized PNG of it.
    - ``gt_surface/``: surface normals derived from that depth.
    - ``rgb_renders/``, ``depth_renders/`` (``.npy`` + colorized PNG),
      ``normal_renders/`` and ``surface_renders/``: the model outputs.

    It also writes ``camera.json`` with each view's c2w, intrinsics and size.

    Args:
        scene_name: Scene identifier; names the view-list JSON and the default result dir.
        dataset: Dataset directory name, used only to build the default ``config_path``.
        model_name: Model directory name, used only to build the default ``config_path``.
        timestamp: Run directory name, used only to build the default ``config_path``.
        result_path: Output directory. Defaults to ``~/benchmark/<scene_name>``.
        config_path: Experiment ``config.yml``. Defaults to
            ``<result_path>/<dataset>/<model_name>/<timestamp>/config.yml``.
        benchmark_path: Optional directory holding ``render_configs/<scene_name>.json``.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
    """
    print("Generating image.")
    if result_path is None:
        result_path = Path.home().joinpath(f"benchmark/{scene_name}")

    if config_path is None:
        config_path = result_path.joinpath(dataset, model_name, timestamp, "config.yml")
        print("Using the config path at: ", config_path)

    gt_path = _make_output_dir(result_path, "gt_images")
    gt_depth_path = _make_output_dir(result_path, "gt_depth")
    gt_surface_path = _make_output_dir(result_path, "gt_surface")
    rgb_path = _make_output_dir(result_path, "rgb_renders")
    depth_path = _make_output_dir(result_path, "depth_renders")
    surface_path = _make_output_dir(result_path, "surface_renders")
    normal_path = _make_output_dir(result_path, "normal_renders")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not config_path.exists():
        raise FileNotFoundError(f"Make sure you are using valid experiment: {config_path}")

    _, pipeline, _, _ = eval_setup(config_path, update_config_callback=update_config_callback)
    model = pipeline.model
    train_dataset = pipeline.datamanager.train_dataset

    filenames = train_dataset.image_filenames
    images = load_configs(
        scene_name,
        benchmark_path=benchmark_path,
        filenames=filenames,
        result_path=result_path,
    )

    export_data = {}
    cmap = matplotlib.colormaps.get_cmap("Spectral_r")

    with torch.no_grad():
        cameras: Cameras = train_dataset.cameras  # type: ignore
        for image_idx, image_file in enumerate(filenames):
            image_name = image_file.name
            if image_name not in images:
                continue
            image_stem = image_file.stem
            # Load only the selected views rather than decoding every training image.
            # NOTE(review): assumes train_dataset[i] corresponds to image_filenames[i]
            # (the camera slice below already relied on this index alignment).
            data = train_dataset[image_idx]
            camera = cameras[image_idx : image_idx + 1]

            depth_image_file = _find_mono_depth_file(image_file)
            # NOTE(review): allow_pickle=True kept from the original; only load trusted .npy files.
            mono_depth = torch.from_numpy(np.load(depth_image_file, allow_pickle=True)).to(camera.device)
            # Factor 4 presumably mirrors camera_res_scale_factor=0.25 from update_config_callback.
            mono_depth = resize_image(mono_depth.unsqueeze(2), 4)
            shutil.copy(depth_image_file, gt_depth_path.joinpath(f"{image_stem}.npy"))

            surface_normal = normal_from_depth_image(
                mono_depth,
                fx=camera.fx.item(),
                fy=camera.fy.item(),
                cx=camera.cx.item(),
                cy=camera.cy.item(),
                img_size=(camera.width.item(), camera.height.item()),
                c2w=torch.eye(4, dtype=torch.float, device=camera.device),
                device=camera.device,
                smooth=False,
            )
            # Negate the y and z components (presumably a camera-convention change — confirm),
            # then map [-1, 1] -> [0, 1] so the normals can be saved as an image.
            surface_normal = surface_normal @ torch.diag(
                torch.tensor([1, -1, -1], device=camera.device, dtype=torch.float)
            )
            surface_normal = (1 + surface_normal) / 2
            save_image(surface_normal.permute(2, 0, 1), gt_surface_path.joinpath(f"{image_stem}.png"))

            save_image(data["image"].permute(2, 0, 1), gt_path.joinpath(f"{image_stem}.png"))

            outputs = model.get_outputs_for_camera(camera=camera)
            save_image(outputs["rgb"].permute(2, 0, 1), rgb_path.joinpath(f"{image_stem}.png"))

            depth = outputs["depth"].detach().squeeze(-1).cpu().numpy()
            np.save(depth_path.joinpath(f"{image_stem}.npy"), depth)

            # Colorize both depths on one shared scale so the two PNGs are directly comparable.
            gt_depth = mono_depth.detach().squeeze(2).cpu().numpy()
            min_depth = min(gt_depth.min(), depth.min())
            max_depth = max(gt_depth.max(), depth.max())
            cv2.imwrite(
                gt_depth_path.joinpath(f"{image_stem}.png").as_posix(),
                _colorize_depth(gt_depth, min_depth, max_depth, cmap),
            )
            cv2.imwrite(
                depth_path.joinpath(f"{image_stem}.png").as_posix(),
                _colorize_depth(depth, min_depth, max_depth, cmap),
            )

            save_image(outputs["normal"].permute(2, 0, 1), normal_path.joinpath(f"{image_stem}.png"))
            save_image(outputs["surface_normal"].permute(2, 0, 1), surface_path.joinpath(f"{image_stem}.png"))
            export_data[image_name] = {
                "path": images[image_name],
                "c2w": camera.camera_to_worlds.squeeze(0).detach().cpu().numpy().tolist(),
                "k": camera.get_intrinsics_matrices().squeeze(0).detach().cpu().numpy().tolist(),
                "w": camera.width.item(),
                "h": camera.height.item(),
            }

    with open(result_path.joinpath("camera.json"), "w", encoding="utf-8") as fp:
        json.dump(export_data, fp, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Render ground-truth and model outputs for a subset of training views."
    )
    parser.add_argument(
        "--output_scene", type=str, required=True,
        help="Scene name; names the view-list JSON and the default result directory.",
    )
    parser.add_argument(
        "--dataset", type=str, required=True,
        help="Dataset directory name (used to build the default --config_path).",
    )
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Model directory name (used to build the default --config_path).",
    )
    parser.add_argument(
        "--timestamp", type=str, required=True,
        help="Run directory name (used to build the default --config_path).",
    )
    # Optional: previously `required=True` made render()'s fallback paths unreachable.
    parser.add_argument(
        "--config_path", type=Path, default=None,
        help="Experiment config.yml; defaults to <result_path>/<dataset>/<model_name>/<timestamp>/config.yml.",
    )
    parser.add_argument(
        "--result_path", type=Path, default=None,
        help="Output directory; defaults to ~/benchmark/<output_scene>.",
    )
    parser.add_argument(
        "--benchmark_path", type=Path, default=None,
        help="Optional directory containing render_configs/<output_scene>.json.",
    )
    args = parser.parse_args()

    render(
        args.output_scene,
        args.dataset,
        args.model_name,
        args.timestamp,
        args.result_path,
        args.config_path,
        args.benchmark_path,
    )
