import os, sys
sys.path.append(os.path.abspath('../../'))

import torch
import numpy as np
import torch.nn.functional as F
from typing import Optional, Tuple

import yaml
import time
import cv2


# Module-level switch read by profile_execution each time a decorated function
# is called; while it is False no timing is measured or printed.
PROFILING_ENABLED = False

def normalized_quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """Convert quaternions in (w, x, y, z) order into rotation matrices.

    No normalization is performed here; the caller is expected to pass
    unit-length quaternions.

    Args:
        quat: Tensor whose last dimension has size 4.

    Returns:
        Tensor of shape ``quat.shape[:-1] + (3, 3)``.
    """
    assert quat.shape[-1] == 4, quat.shape
    w, x, y, z = quat.unbind(dim=-1)

    # The nine matrix entries, listed row by row.
    first_row = (
        1 - 2 * (y**2 + z**2),
        2 * (x * y - w * z),
        2 * (x * z + w * y),
    )
    second_row = (
        2 * (x * y + w * z),
        1 - 2 * (x**2 + z**2),
        2 * (y * z - w * x),
    )
    third_row = (
        2 * (x * z - w * y),
        2 * (y * z + w * x),
        1 - 2 * (x**2 + y**2),
    )

    flat = torch.stack(first_row + second_row + third_row, dim=-1)
    batch_shape = quat.shape[:-1]
    return flat.reshape(batch_shape + (3, 3))


def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """Convert (possibly unnormalized) quaternions into rotation matrices.

    The quaternions are normalized along the last dimension and then handed
    to ``normalized_quat_to_rotmat``.

    Args:
        quat: Tensor whose last dimension has size 4.

    Returns:
        Tensor of shape ``quat.shape[:-1] + (3, 3)``.
    """
    assert quat.shape[-1] == 4, quat.shape
    unit_quat = F.normalize(quat, dim=-1)
    return normalized_quat_to_rotmat(unit_quat)


def RGB2SH(rgb):
    """
    Map RGB values in [0, 1] to the 0th-order spherical harmonic coefficient.

    The value is shifted by 0.5 and divided by the constant C0.
    """
    sh_c0 = 0.28209479177387814
    centered = rgb - 0.5
    return centered / sh_c0

def load_cfg(path):
    """Read the YAML file at ``path`` and return its parsed content.

    NOTE(review): parsing uses ``yaml.FullLoader``. If configuration files can
    ever come from an untrusted source, consider ``yaml.safe_load`` instead.
    """
    with open(path, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)
    return config


def profile_execution(func):
    """A decorator to measure the execution time of a function.

    When the module-level ``PROFILING_ENABLED`` flag is true at call time, the
    wrapped call is timed with ``time.perf_counter`` and the resulting rate
    (1 / elapsed seconds) is printed as FPS. Otherwise the function is called
    directly with no overhead beyond the flag check.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature and return value as ``func``;
        its ``__name__``/``__doc__`` metadata is copied from ``func``.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Read the flag exactly once so that toggling it while ``func`` runs
        # can never leave ``start`` undefined.
        if not PROFILING_ENABLED:
            return func(*args, **kwargs)

        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start

        # A zero reading is possible on coarse clocks; avoid dividing by zero.
        fps = 1 / elapsed if elapsed > 0 else float("inf")
        print(f"{func.__name__} executed in {fps:.1f} FPS")
        return result

    return wrapper


def profile_step(func_name, start_time, profiling_enabled):
    """Print the rate (1 / seconds elapsed since ``start_time``) for a step.

    Args:
        func_name: Label printed in front of the rate.
        start_time: A ``time.time()`` timestamp taken when the step began.
        profiling_enabled: Nothing is measured or printed when this is falsy.
    """
    if not profiling_enabled:
        return

    elapsed = time.time() - start_time
    fps = 1 / elapsed
    print(f"{func_name} ({fps:.1f} FPS)")


def camera_to_world(position, direction):
    """Build a 4x4 camera-to-world matrix from a position and a view direction.

    The rotation columns are (right, up, forward), where ``forward`` is the
    normalized ``direction``. The reference up vector is (0, -1, 0); when the
    view direction is (anti)parallel to it, (0, 0, 1) is used instead so the
    cross product does not degenerate.

    Args:
        position: NumPy array that can be reshaped to (3, 1).
        direction: NumPy array with the viewing direction (any non-zero length).

    Returns:
        ``torch.Tensor`` of shape (4, 4) created from the NumPy result.
    """
    forward = direction / np.linalg.norm(direction)

    up_hint = np.array([0, -1, 0], dtype=float)
    is_parallel = np.allclose(forward, up_hint) or np.allclose(forward, -up_hint)
    if is_parallel:
        up_hint = np.array([0, 0, 1], dtype=float)

    # Orthonormal basis around the forward axis.
    right = np.cross(up_hint, forward)
    right = right / np.linalg.norm(right)
    up = np.cross(forward, right)
    up = up / np.linalg.norm(up)

    basis = np.column_stack((right, up, forward))
    top_rows = np.hstack((basis, position.reshape(3, 1)))
    c2w = np.vstack((top_rows, [0, 0, 0, 1]))
    return torch.from_numpy(c2w)


def create_image_ray_bundle(camera_position, camera_direction, f, W, H, device="cuda"):
    """Create one ray per pixel for a pinhole camera looking along a direction.

    The pose comes from ``camera_to_world``. The intrinsics use focal length
    ``f`` on both axes and a principal point at (W // 2, H // 2). Rays pass
    through pixel centers (pixel index + 0.5) and are ordered row by row.

    Args:
        camera_position: Camera position, forwarded to ``camera_to_world``.
        camera_direction: Viewing direction, forwarded to ``camera_to_world``.
        f: Focal length in pixels.
        W: Image width in pixels.
        H: Image height in pixels.
        device: Device the returned tensors are moved to (default ``"cuda"``,
            matching the previous hard-coded behavior).

    Returns:
        Tuple ``(rays_o, rays_d)``, each of shape (H*W, 3); ``rays_d`` is
        unit length and ``rays_o`` repeats the camera position.
    """
    pose = camera_to_world(camera_position, camera_direction).float()

    K = torch.tensor([[f, 0, W // 2], [0, f, H // 2], [0, 0, 1]], dtype=torch.float32)

    # Explicit "ij" indexing keeps the historical layout (rows first) without
    # triggering the torch.meshgrid deprecation warning.
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    pixels = torch.stack([cols, rows], dim=-1).reshape(-1, 2).float() + 0.5  # (H*W, 2) as (x, y)
    pixels_h = torch.cat([pixels, torch.ones(H * W, 1, dtype=torch.float32)], dim=1)  # (H*W, 3)

    dirs = pixels_h @ torch.inverse(K).T  # camera-space directions, (H*W, 3)

    rays_d = F.normalize(dirs @ pose[:3, :3].t(), dim=-1)  # rotate into world space
    rays_o = pose[:3, 3].expand(rays_d.shape)
    return rays_o.to(device), rays_d.to(device)


def create_single_ray_bundle(camera_position, camera_direction, device="cuda"):
    """Create a single ray from a camera position and direction.

    Args:
        camera_position: Three values giving the ray origin.
        camera_direction: Three values giving the ray direction (used as-is,
            no normalization is applied).
        device: Device on which the tensors are created (default ``"cuda"``,
            matching the previous hard-coded behavior).

    Returns:
        Tuple ``(ro, rd)`` of float32 tensors, each of shape (1, 3).
    """
    ro = torch.tensor(camera_position, dtype=torch.float32, device=device).reshape(1, 3)
    rd = torch.tensor(camera_direction, dtype=torch.float32, device=device).reshape(1, 3)
    return ro, rd


def create_double_ray_bundle(camera_position, camera_direction, radius):
    """Wrap a camera position and direction as (1, 3) float32 tensors.

    ``radius`` is accepted but not used by the current implementation, and
    the tensors are left on the default device.

    Returns:
        Tuple ``(ro, rd)``: ray origin and ray direction, each of shape (1, 3).
    """
    origin, heading = (
        torch.tensor(values, dtype=torch.float32).reshape(1, 3)
        for values in (camera_position, camera_direction)
    )
    return origin, heading


# Normal-related depth computation and losses.
def compute_normals(gaussians, mu, quats, scales, c2w):
    """Compute one normal per Gaussian from its shortest axis.

    The normal is the axis with the smallest scale, rotated by the Gaussian's
    quaternion and flipped where needed so it points towards the camera
    position ``c2w[..., :3, 3]``. A detached copy of these normals is stored in
    ``gaussians.gauss_params["normals"]``.

    Args:
        gaussians: Object exposing a ``gauss_params`` mapping (written to).
        mu: Gaussian centers; assumed (N, 3) -- TODO confirm.
        quats: Quaternions of shape (N, 4).
        scales: Per-axis scales of shape (N, 3).
        c2w: Camera-to-world matrix.

    Returns:
        The normals right-multiplied by ``c2w[:3, :3]``.
    """
    # One-hot vector selecting the axis of smallest extent.
    thinnest_axis = torch.argmin(scales, dim=-1)
    axis_vectors = torch.nn.functional.one_hot(thinnest_axis, num_classes=3).float()

    rotations = quat_to_rotmat(quats)
    rotated = torch.bmm(rotations, axis_vectors[:, :, None]).squeeze(-1)
    normals = F.normalize(rotated, dim=1)

    # Unit vectors from each Gaussian towards the camera center.
    to_camera = -mu.clone().detach() + c2w.clone().detach()[..., :3, 3]
    to_camera = to_camera / to_camera.norm(dim=-1, keepdim=True)

    # Flip normals that face away from the camera.
    facing_away = (normals * to_camera).sum(-1) < 0
    normals = torch.where(facing_away.unsqueeze(-1), -normals, normals)

    gaussians.gauss_params["normals"] = normals.clone().detach()
    return normals @ c2w[:3, :3]


def pcd_to_normal(xyz):
    """Estimate per-pixel normals of an organized point cloud.

    For every interior pixel the normal is the normalized cross product of
    the (right - left) and (top - bottom) neighbor differences. The one-pixel
    border is filled with zeros.

    Args:
        xyz: Tensor of shape (H, W, 3) holding one 3D point per pixel.

    Returns:
        Tensor of shape (H, W, 3) with the estimated normals.
    """
    rows, cols, _ = xyz.shape

    # Central differences over the interior region.
    horizontal = xyz[..., 1 : rows - 1, 2:cols, :] - xyz[..., 1 : rows - 1, 0 : cols - 2, :]
    vertical = xyz[..., 0 : rows - 2, 1 : cols - 1, :] - xyz[..., 2:rows, 1 : cols - 1, :]

    interior = torch.cross(horizontal, vertical, dim=-1)
    interior = F.normalize(interior, p=2, dim=-1)

    # Pad in channel-first layout, then restore (H, W, 3).
    channels_first = interior.permute(2, 0, 1)
    padded = F.pad(channels_first, (1, 1, 1, 1), mode="constant")
    return padded.permute(1, 2, 0)


def get_camera_coords(img_size: tuple, pixel_offset: float = 0.5):
    """Generate pixel coordinates for an image of size (W, H).

    Args:
        img_size: Image size given as (width, height).
        pixel_offset: Value added to every integer pixel index.

    Returns:
        Float tensor of shape [H*W, 2], ordered row by row, where [:, 0] is
        the coordinate along W and [:, 1] the coordinate along H.
    """
    width, height = img_size[0], img_size[1]

    # "xy" indexing yields (H, W) grids with x varying along the columns.
    grid_x, grid_y = torch.meshgrid(
        torch.arange(width),
        torch.arange(height),
        indexing="xy",
    )
    grid = torch.stack((grid_x, grid_y), dim=-1) + pixel_offset
    flat = grid.view(-1, 2)
    return flat.float()


def get_means3d_backproj(
    depths: torch.Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    img_size: tuple,
    c2w: torch.Tensor,
    device: torch.device
):
    """Backprojection using camera intrinsics and extrinsics

    image_coords -> (x,y,depth) -> (X, Y, depth)

    Args:
        depths: Depth values, one per pixel in row-major order. Any layout
            with W*H elements is accepted; it is flattened internally.
        fx: Focal length along x, in pixels.
        fy: Focal length along y, in pixels.
        cx: Principal point x coordinate.
        cy: Principal point y coordinate.
        img_size: Image size given as (width, height).
        c2w: Camera-to-world matrix of shape (..., 4, 4), or ``None`` to keep
            the points in camera coordinates (identity pose).
        device: Device on which the computation is carried out.

    Returns:
        Tuple of (means: Tensor, image_coords: Tensor). ``means`` holds the
        world-space points ((H*W, 3) for an unbatched ``c2w``) and
        ``image_coords`` the (H*W, 2) pixel coordinates that were used.
    """
    # Flatten to one depth per pixel; reshape also copes with non-contiguous input.
    depths = depths.reshape(-1, 1).to(device=device, dtype=torch.float32)

    # Resolve the pose before touching any of its attributes, and always
    # bring it to float32 so it matches the points below.
    if c2w is None:
        c2w = torch.eye(4, dtype=torch.float32, device=device)
    else:
        c2w = c2w.to(device=device, dtype=torch.float32)

    image_coords = get_camera_coords(img_size)
    image_coords = image_coords.to(device)  # (H*W, 2) as (x, y)

    # TODO: account for skew / radial distortion
    z = depths[:, 0]
    means3d = torch.stack(
        [
            (image_coords[:, 0] - cx) * z / fx,  # x
            (image_coords[:, 1] - cy) * z / fy,  # y
            z,  # z
        ],
        dim=-1,
    )

    # to world coords: p_world = R @ p_cam + t, written for row vectors.
    rotation = c2w[..., :3, :3]
    means3d = means3d @ rotation.transpose(-1, -2) + c2w[..., :3, 3]
    return means3d, image_coords


def normal_from_depth_image(
    depths: torch.Tensor,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    img_size: tuple,
    c2w: torch.Tensor,
    device: torch.device,
    smooth: bool = False,
):
    """estimate normals from depth map

    The depth map is back-projected to 3D points with ``get_means3d_backproj``
    and normals are computed from neighboring points with ``pcd_to_normal``.

    Args:
        depths: Depth map for an image of size ``img_size``.
        fx: Focal length along x, in pixels.
        fy: Focal length along y, in pixels.
        cx: Principal point x coordinate.
        cy: Principal point y coordinate.
        img_size: Image size given as (width, height).
        c2w: Camera-to-world matrix used for the back-projection.
        device: Device on which the computation is carried out.
        smooth: When True, a 9x9 Gaussian blur is applied to the depth map
            first, unless the map contains zero-valued entries.

    Returns:
        Tensor of shape (H, W, 3) with the estimated normals.
    """
    if smooth:
        # Blurring would bleed zero-valued depths into their neighbors, so
        # smoothing is skipped as soon as any zero is present.
        if torch.any(depths == 0):
            print("Input depth map contains 0 elements, skipping smoothing filter")
        else:
            kernel_size = (9, 9)
            blurred = cv2.GaussianBlur(depths.detach().cpu().numpy(), kernel_size, 0)
            depths = torch.from_numpy(blurred).to(device)
    means3d, _ = get_means3d_backproj(depths, fx, fy, cx, cy, img_size, c2w, device)
    # img_size is (W, H); the point grid is laid out as (H, W, 3).
    means3d = means3d.reshape(img_size[1], img_size[0], 3)
    normals = pcd_to_normal(means3d)
    return normals

def compute_normal_from_depth(depth, camera, device):
    """Compute a normal map from a depth map, remapped to the [0, 1] range.

    Normals are estimated in camera coordinates (identity pose) without
    smoothing, their y and z components are negated, and the result is mapped
    with ``(1 + n) / 2``.

    Args:
        depth: Depth map tensor.
        camera: Object exposing ``fx``, ``fy``, ``cx``, ``cy``, ``width`` and
            ``height``, each supporting ``.item()``.
        device: Device on which the normals are computed.

    Returns:
        Tensor of shape (H, W, 3) with values in [0, 1].
    """
    gt_normal = normal_from_depth_image(
        depths=depth,
        fx=camera.fx.item(),
        fy=camera.fy.item(),
        cx=camera.cx.item(),
        cy=camera.cy.item(),
        img_size=(camera.width.item(), camera.height.item()),
        c2w=torch.eye(4, dtype=torch.float, device=depth.device),
        device=device,
        smooth=False,
    )
    # Build the axis flip on the normals' own device/dtype: they live on
    # ``device``, which need not match ``depth.device`` or ``depth.dtype``.
    axis_flip = torch.diag(
        torch.tensor(
            [1, -1, -1], device=gt_normal.device, dtype=gt_normal.dtype
        )
    )
    gt_normal = gt_normal @ axis_flip
    gt_normal = (1 + gt_normal) / 2
    return gt_normal


def project_pix(
    p: torch.Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    c2w: torch.Tensor,
    device: torch.device,
    return_z_depths: bool = False,
) -> torch.Tensor:
    """Projects a world 3D point to uv coordinates using intrinsics/extrinsics

    Args:
        p: World-space points of shape (N, 3).
        fx: Focal length along x, in pixels.
        fy: Focal length along y, in pixels.
        cx: Principal point x coordinate.
        cy: Principal point y coordinate.
        c2w: Camera-to-world matrix of shape (..., 4, 4), or ``None`` to
            treat ``p`` as already expressed in camera coordinates.
        device: Device on which the computation is carried out.
        return_z_depths: When True, the camera-space z value is appended as
            a third column.

    Returns:
        uv coords of shape (N, 2), or (N, 3) when ``return_z_depths`` is True
    """
    if c2w is None:
        # Identity pose: a single 4x4 matrix is enough for all points.
        c2w = torch.eye(4, dtype=p.dtype, device=device)
    else:
        c2w = c2w.to(device)

    # Row-vector form of R^T @ (p - t).
    points_cam = (p.to(device) - c2w[..., :3, 3]) @ c2w[..., :3, :3]
    u = points_cam[:, 0] * fx / points_cam[:, 2] + cx  # x
    v = points_cam[:, 1] * fy / points_cam[:, 2] + cy  # y
    if return_z_depths:
        return torch.stack([u, v, points_cam[:, 2]], dim=-1)
    return torch.stack([u, v], dim=-1)


def get_colored_points_from_depth(
    depths: torch.Tensor,
    rgbs: torch.Tensor,
    c2w: torch.Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    img_size: tuple,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a colored point cloud from a depth frame, an RGB frame and a pose.

    Points are back-projected on ``depths.device``; colors are ``rgbs``
    flattened to rows of three values. When ``mask`` is given it is applied
    to both points and colors.

    Returns:
        Tuple of (points, colors)
    """
    target_device = depths.device
    points, _ = get_means3d_backproj(
        depths=depths.float(),
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        img_size=img_size,
        c2w=c2w.float(),
        device=target_device,
    )
    points = points.squeeze(0)

    if mask is None:
        return (points, rgbs.view(-1, 3))

    # Accept array-likes as masks by converting them to a tensor first.
    if not torch.is_tensor(mask):
        mask = torch.tensor(mask, device=depths.device)
    colors = rgbs.view(-1, 3)[mask]
    points = points[mask]
    return (points, colors)


def generate_camera_positions(N, initial_N=6, radius=5.0):
    """
    Generate camera positions and directions around the origin in a circular pattern at different heights.

    Cameras are placed on rings of the given radius in the xy-plane. Rings are
    stacked along z, one unit apart and centered around z=0. Every ring holds
    ``initial_N`` cameras except the last one, which takes whatever remains.

    Parameters:
    - N (int): Total number of cameras to generate.
    - initial_N (int): Initial number of cameras per height (default is 6).
    - radius (float): Radius of each ring (default is 5.0).

    Returns:
    - positions (np.ndarray): An array of shape (N, 3) containing the (x, y, z) positions of the cameras.
    - directions (np.ndarray): An array of shape (N, 3) containing the normalized direction vectors pointing to the origin.
    """
    if N <= initial_N:
        num_levels = 1
        cameras_per_level = N
    else:
        num_levels = -(-N // initial_N)  # ceiling division
        cameras_per_level = initial_N

    # Center the heights around z=0
    half_span = (num_levels - 1) / 2
    heights = np.linspace(-half_span, half_span, num_levels)

    rings = []
    remaining = N
    for level, height in enumerate(heights):
        # The last level takes whatever is left over.
        count = remaining if level == num_levels - 1 else cameras_per_level
        remaining -= count

        # Distribute cameras evenly around the circle at the current height
        angles = np.linspace(0, 2 * np.pi, count, endpoint=False)
        ring = np.stack(
            [
                radius * np.cos(angles),
                radius * np.sin(angles),
                np.full_like(angles, height),
            ],
            axis=-1,
        )
        rings.append(ring)

    # Concatenating (count, 3) blocks keeps the (N, 3) shape even when N == 0.
    positions = np.concatenate(rings, axis=0) if rings else np.empty((0, 3))

    # Unit vectors pointing at the origin; a camera at the origin keeps a zero direction.
    norms = np.linalg.norm(positions, axis=1, keepdims=True)
    directions = -positions
    np.divide(directions, norms, out=directions, where=norms != 0)

    return positions, directions