import matplotlib.pyplot as plt
import numpy as np
import pydicom as dicom
from skimage.morphology import binary_closing
from skimage.morphology import disk
from skimage import color, io, measure, img_as_ubyte, img_as_float
from scipy.spatial import distance
from skimage.transform import rotate
import glob
from sklearn.decomposition import PCA
from skimage.transform import rescale


def pca_on_ct_image_f_2025():
    """Segment, compare and run PCA on a stack of chest CT slices.

    Steps (results are printed as ``Answer: ...`` lines):

    1. Read eleven DICOM slices from ``data/chest_ct/``; each flattened
       ``pixel_array`` becomes one row of a float data matrix.
    2. Threshold the pixel-wise average slice to the open interval
       (100, 500), apply a binary closing with ``disk(5)`` and label the
       connected components; print the number of labels.
    3. Keep only components with area in [2000, 5000] and perimeter in
       [400, 600], printing the area and perimeter of each kept component.
    4. Threshold slice ``1-115.dcm`` the same way and print the DICE score
       between it and the kept-component mask.
    5. Fit a 5-component PCA to the data matrix, plot PC1 against PC2, and
       print which training slice is closest (Euclidean distance in PCA
       space) to slice ``1-105.dcm``.

    NOTE(review): the thresholds are named like Hounsfield units but are
    applied to raw ``pixel_array`` values with no rescale slope/intercept —
    confirm that this is intended for this data set.
    """
    ct_dir = r"data/chest_ct/"
    slice_names = ["1-100.dcm", "1-110.dcm", "1-120.dcm", "1-130.dcm",
                   "1-140.dcm", "1-150.dcm", "1-160.dcm", "1-170.dcm",
                   "1-180.dcm", "1-190.dcm", "1-200.dcm"]
    min_hu, max_hu = 100, 500
    min_area, max_area = 2000, 5000
    min_per, max_per = 400, 600

    # The first slice fixes the image size; each slice then fills one row.
    first_slice = dicom.dcmread(ct_dir + slice_names[0]).pixel_array
    height, width = first_slice.shape[0], first_slice.shape[1]
    data_matrix = np.zeros((len(slice_names), height * width))
    for row, name in enumerate(slice_names):
        data_matrix[row, :] = dicom.dcmread(ct_dir + name).pixel_array.flatten()

    # Segment the average slice: intensity window, closing, labelling.
    mean_slice = np.mean(data_matrix, 0).reshape(height, width)
    in_window = (mean_slice > min_hu) & (mean_slice < max_hu)
    labels = measure.label(binary_closing(in_window, disk(5)))
    print(f"Answer: Number of labels: {labels.max()}")

    # Keep components whose area AND perimeter both lie in their ranges.
    kept_labels = []
    for region in measure.regionprops(labels):
        area, perimeter = region.area, region.perimeter
        if min_area <= area <= max_area and min_per <= perimeter <= max_per:
            print(f"Answer: Area {area:.0f} Perimeter {perimeter:.0f}")
            kept_labels.append(region.label)
    kept_mask = np.isin(labels, kept_labels)

    # Compare against a thresholded reference slice.
    ref_slice = dicom.dcmread(f"{ct_dir}1-115.dcm").pixel_array
    ref_mask = (ref_slice > min_hu) & (ref_slice < max_hu)
    dice_score = 1 - distance.dice(kept_mask.ravel(), ref_mask.ravel())
    print(f"Answer: DICE score {dice_score:.3f}")

    # PCA on the stack of flattened slices.
    pca = PCA(n_components=5)
    pca.fit(data_matrix)
    scores = pca.transform(data_matrix)

    plt.plot(scores[:, 0], scores[:, 1], '.')
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.show()

    # Project the unseen slice and find its nearest training slice.
    query = dicom.dcmread(f"{ct_dir}1-105.dcm").pixel_array.flatten().reshape(1, -1)
    query_scores = pca.transform(query).flatten()
    dist_to_query = np.linalg.norm(scores - query_scores, axis=1)
    nearest = np.argmin(dist_to_query)
    print(f"Answer: Best matching image {slice_names[nearest]}")


def change_detection_f2025():
    """Detect changed pixels in the last frame of a short video clip.

    The six frames ``movie_000.jpg`` ... ``movie_005.jpg`` are read from
    ``data/video/`` and converted to gray scale. A background is built by
    starting from frame 0 and blending in frames 1-4 with
    ``background = 0.80 * background + 0.20 * frame``; its maximum value is
    printed. Frame 5 is then compared with this background. Pixels whose
    absolute difference exceeds 0.2 are marked as changed. The binary mask
    is displayed, the number of changed pixels is printed, and the number of
    connected components in the mask is printed.
    """
    video_dir = "data/video/"
    frame_names = ["movie_000.jpg", "movie_001.jpg", "movie_002.jpg", "movie_003.jpg", "movie_004.jpg", "movie_005.jpg"]
    frames = [color.rgb2gray(io.imread(f"{video_dir}{name}")) for name in frame_names]

    alpha = 0.80
    background = frames[0]
    for frame in frames[1:5]:
        background = alpha * background + (1 - alpha) * frame

    print(f"Answer: maximum value : {np.max(background):.2f}")

    changed = np.abs(frames[5] - background) > 0.2
    plt.imshow(changed, cmap='gray')
    plt.title('The Binary Difference Image')
    plt.show()
    print(f"Answer: Changed pixels {np.sum(changed)}")

    component_count = measure.label(changed).max()
    print(f"Answer: Number of labels: {component_count}")


def satellite_image_transformation_and_bilinear_interpolation_f_2025():
    """Transform satellite ROI pixels and resample by bilinear interpolation.

    The combined homogeneous transform is ``S . (R . T)``: translate by
    (5, 5), rotate by 10 degrees, then scale by 0.5. The function then:

    * prints element [0][0] of the combined matrix;
    * forward-maps every pixel equal to 1 in ``ROI_A.png`` and prints the
      floored (x, y) results;
    * backward-maps (inverse transform) every pixel equal to 1 in
      ``ROI_B.png`` into ``satellite_A.png``, samples there with bilinear
      interpolation, and prints the mean of the samples.

    Sample positions are not bounds-checked. A backward-mapped position past
    the image edge raises IndexError, and a negative one silently wraps
    around through NumPy negative indexing.
    """
    sat_dir = "data/satellite/"
    satA = io.imread(f'{sat_dir}satellite_A.png', as_gray=True)
    roi_a = io.imread(f'{sat_dir}ROI_A.png', as_gray=True)

    # Homogeneous 3x3 building blocks.
    translation = np.array([[1, 0, 5], [0, 1, 5], [0, 0, 1]])
    angle = np.deg2rad(10)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Standard rotation matrix. Image rows (y) grow downwards, so a positive
    # angle turns points clockwise as seen on screen.
    rotation = np.array([[cos_a, -sin_a, 0],
                         [sin_a, cos_a, 0], [0, 0, 1]])
    scale = 0.5
    scaling = np.array([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])
    transform = np.dot(scaling, np.dot(rotation, translation))

    print(f"Answer: trans matrix value: {transform[0][0]:.2f}")

    # Forward mapping of the ROI_A pixels (x = column, y = row).
    rows_a, cols_a = np.where(roi_a == 1)
    forward = [np.floor(np.dot(transform, np.array([col, row, 1]))[0:2])
               for col, row in zip(cols_a, rows_a)]

    print("Answer: Transformed positions:", forward)

    # Backward mapping of the ROI_B pixels into satellite_A coordinates.
    roi_b = io.imread(f'{sat_dir}ROI_B.png', as_gray=True)
    rows_b, cols_b = np.where(roi_b == 1)
    inverse = np.linalg.inv(transform)
    backward = [np.dot(inverse, np.array([col, row, 1]))[0:2]
                for col, row in zip(cols_b, rows_b)]

    # Bilinear interpolation at each backward-mapped position.
    samples = np.zeros(len(backward))
    for k, (xt, yt) in enumerate(backward):
        x_lo, y_lo = np.floor(xt), np.floor(yt)
        x_hi, y_hi = np.ceil(xt), np.ceil(yt)
        dx = xt - x_lo
        dy = yt - y_lo

        top_left = satA[int(y_lo), int(x_lo)]
        top_right = satA[int(y_lo), int(x_hi)]
        bottom_left = satA[int(y_hi), int(x_lo)]
        bottom_right = satA[int(y_hi), int(x_hi)]

        samples[k] = (top_left * (1 - dx) * (1 - dy)
                      + top_right * dx * (1 - dy)
                      + bottom_left * (1 - dx) * dy
                      + bottom_right * dx * dy)

    print(f"Answer: Mean of bilinear interpolation values: {np.mean(samples):.1f}")


def gradient_descent_f_2025_v2(x_1_start=-20, x_2_start=20, step_length=0.02, n_steps=13):
    """Plot a fixed-step gradient-descent path on C = 12*x1**2 + x1*x2 + 5*x2**2.

    The gradient used is ``(24*x1 + x2, x1 + 10*x2)``. Starting from
    (``x_1_start``, ``x_2_start``), ``n_steps`` updates of size
    ``step_length`` are taken. The start point and every updated point are
    drawn as a connected line and as a scatter coloured by the cost at that
    same point. Axes are fixed to [-20, 20] x [-20, 20].

    Args:
        x_1_start: starting x1 (default -20).
        x_2_start: starting x2 (default 20).
        step_length: step size multiplying the gradient (default 0.02).
        n_steps: number of gradient updates (default 13).
    """
    def cost(x_1, x_2):
        return 12 * x_1 * x_1 + x_1 * x_2 + 5 * x_2 * x_2

    x_1, x_2 = x_1_start, x_2_start
    # Record each point together with ITS OWN cost, so the scatter colour of
    # a point matches that point and the final iterate is also drawn.
    x_1_s, x_2_s, cs = [x_1], [x_2], [cost(x_1, x_2)]
    for _ in range(n_steps):
        grad_x_1 = 24 * x_1 + x_2
        grad_x_2 = x_1 + 10 * x_2
        # Simultaneous update: both components use the gradient at the old point.
        x_1, x_2 = x_1 - step_length * grad_x_1, x_2 - step_length * grad_x_2
        x_1_s.append(x_1)
        x_2_s.append(x_2)
        cs.append(cost(x_1, x_2))

    plt.scatter(x_1_s, x_2_s, c=cs)
    plt.plot(x_1_s, x_2_s)
    plt.xlim([-20, 20])
    plt.ylim([-20, 20])
    plt.show()


def animal_similarity_f_2025():
    """Compare two animal images using normalized cross-correlation (NCC).

    ``ImageA.png`` and ``ImageB.png`` from ``data/animals/`` are converted to
    gray scale, standardised (subtract mean, divide by standard deviation),
    shifted by +50, and subsampled by keeping every second row and column.
    A histogram of processed image A is shown. The NCC of the two processed
    images is printed, together with the angle ``arccos(NCC)`` in degrees.
    """
    in_dir = "data/animals/"
    imgA = io.imread(f'{in_dir}ImageA.png')
    imgB = io.imread(f'{in_dir}ImageB.png')
    imgA_gray = color.rgb2gray(imgA)
    imgB_gray = color.rgb2gray(imgB)

    # Standardise each image. The +50 offset has no effect on the NCC below,
    # because the NCC subtracts the mean again.
    imgA_gray_normalized = (imgA_gray - np.mean(imgA_gray)) / np.std(imgA_gray) + 50
    imgB_gray_normalized = (imgB_gray - np.mean(imgB_gray)) / np.std(imgB_gray) + 50

    # Take every second pixel in both images.
    samples_a = imgA_gray_normalized[::2, ::2].flatten()
    samples_b = imgB_gray_normalized[::2, ::2].flatten()

    # Histogram of image A.
    plt.hist(samples_a, bins=50, range=(np.min(samples_a), np.max(samples_a)), alpha=0.5, label='Image A')
    plt.show()

    # NCC = <a - mean(a), b - mean(b)> / (|a - mean(a)| * |b - mean(b)|)
    centred_a = samples_a - np.mean(samples_a)
    centred_b = samples_b - np.mean(samples_b)
    numerator = np.sum(centred_a * centred_b)
    denominator = np.linalg.norm(centred_a) * np.linalg.norm(centred_b)
    ncc = numerator / denominator
    print(f"Normalized Cross-Correlation (NCC): {ncc:.2f}")

    # Rounding can push |ncc| marginally above 1 (e.g. for identical images),
    # which would make arccos return NaN; clip to its valid domain first.
    ncc_angle = np.arccos(np.clip(ncc, -1.0, 1.0)) * 180 / np.pi
    print(f"NCC angle in degrees: {ncc_angle:.2f}")


# Use the analytic solution (page 135 of the book)
def gauss_intercept(sigma1, sigma2, mu1, mu2):
    """Return the points where two Gaussian probability densities are equal.

    Setting N(x; mu1, sigma1) == N(x; mu2, sigma2) and taking logarithms gives
    the quadratic

        (s1^2 - s2^2) x^2 - 2 (s1^2 mu2 - s2^2 mu1) x
            + s1^2 mu2^2 - s2^2 mu1^2 + 2 s1^2 s2^2 ln(s2 / s1) = 0

    with discriminant/4 = s1^2 s2^2 ((mu1 - mu2)^2 - 2 ln(s2/s1) (s1^2 - s2^2)).
    For positive sigmas ``ln(s2/s1)`` and ``s1^2 - s2^2`` have opposite signs,
    so the discriminant is never negative and two real roots exist.

    Args:
        sigma1: standard deviation of the first distribution.
        sigma2: standard deviation of the second distribution.
        mu1: mean of the first distribution.
        mu2: mean of the second distribution.

    Returns:
        ``[v_plus, v_minus]``: the roots taken with ``+sqrt`` and ``-sqrt``
        respectively. Which root lies between the means depends on the sign
        of ``sigma1**2 - sigma2**2``. When ``sigma1 == sigma2`` the quadratic
        term vanishes and the single crossing ``(mu1 + mu2) / 2`` is returned
        in both slots.
    """
    if sigma1 == sigma2:
        # Equal widths: the x**2 terms cancel, leaving a linear equation whose
        # solution is the midpoint. The general formula would divide by zero.
        midpoint = (mu1 + mu2) / 2
        return [midpoint, midpoint]

    var1 = sigma1 ** 2
    var2 = sigma2 ** 2
    log_ratio = np.log(sigma2 / sigma1)

    in_front = var1 * mu2 - var2 * mu1
    # (mu1 - mu2)**2 avoids the cancellation of mu1**2 - 2*mu1*mu2 + mu2**2.
    in_sqrt = np.sqrt(var1 * var2 * ((mu1 - mu2) ** 2 - 2 * log_ratio * (var1 - var2)))
    in_denominator = var1 - var2
    v_plus = (in_front + in_sqrt) / in_denominator
    v_minus = (in_front - in_sqrt) / in_denominator
    return [v_plus, v_minus]

def car_segmentation_f_2025():
    """Find the background (BG) gray-level range and paint it white.

    Training pixels for class 1, class 2 and the BG class are taken from
    ``CarTraining.png`` at the non-zero pixels of ``Class1.png``,
    ``Class2.png`` and ``BG.png`` (all in ``data/car_navigation/``). Each class
    is modelled as a Gaussian with its sample mean and standard deviation.
    The BG range runs from the class-2/BG crossing (lower bound) to the
    class-1/BG crossing (upper bound). The range is printed, and every pixel
    inside it is set to 255 and displayed.
    """
    data_dir = "data/car_navigation/"
    image = io.imread(f'{data_dir}CarTraining.png')
    mask_1 = io.imread(f'{data_dir}Class1.png')
    mask_2 = io.imread(f'{data_dir}Class2.png')
    mask_bg = io.imread(f'{data_dir}BG.png')

    def _gaussian_fit(mask):
        # Sample mean and standard deviation of the pixels selected by mask.
        samples = image[mask > 0]
        return np.mean(samples), np.std(samples)

    mean_1, std_1 = _gaussian_fit(mask_1)
    mean_2, std_2 = _gaussian_fit(mask_2)
    mean_bg, std_bg = _gaussian_fit(mask_bg)

    # Root indices were chosen for this data set; the original notes record
    # roughly 175 for the upper bound and 123 for the lower bound.
    upper = gauss_intercept(std_1, std_bg, mean_1, mean_bg)[0]
    lower = gauss_intercept(std_2, std_bg, mean_2, mean_bg)[1]

    print(f"Answer: Class range BG: {lower:.0f} - {upper:.0f}")

    # Paint every pixel inside [lower, upper) white.
    painted = image.copy()
    painted[(painted >= lower) & (painted < upper)] = 255
    plt.imshow(painted, cmap='gray')
    plt.title('Image A with Background')
    plt.axis('off')
    plt.show()


def pedestrian_registration_f2025():
    """Align two pedestrians' landmarks by pure translation and report the SSD.

    ``CarTraining_landmarks.png`` in ``data/pedestrian/`` is read as a label
    image, and the (x, y) centroid of each label 1-8 is computed. Labels 1-4
    form the first human and labels 5-8 the second. The first set is shifted
    so that its centre of mass coincides with the second set's. The
    translation vector and the sum of squared landmark distances after
    alignment are printed.
    """
    data_dir = "data/pedestrian/"
    landmark_img = io.imread(f"{data_dir}CarTraining_landmarks.png")

    # One (x, y) centroid per landmark label; x is the column, y the row.
    centroids = np.zeros((8, 2))
    for label in range(1, 9):
        rows, cols = np.where(landmark_img == label)
        centroids[label - 1] = (np.mean(cols), np.mean(rows))

    first_human = centroids[:4]
    second_human = centroids[4:]

    translation_vector = np.mean(first_human, axis=0) - np.mean(second_human, axis=0)
    print("Answer: Translation vector:", translation_vector)

    aligned_first = first_human - translation_vector
    ssd = np.sum((aligned_first - second_human) ** 2)
    print("Answer: SUM Squared Distance (SSD):", ssd)


def teabag_analysis_f_2025():
    """Show a binary mask of strongly blue pixels in the tea-bag image.

    A pixel of ``data/tea_bags/TeaBag.png`` is selected when its blue channel
    is above 100 and its red and green channels are both below 100.
    """
    tea_dir = "data/tea_bags/"
    tea_bag = io.imread(f"{tea_dir}TeaBag.png")

    red, green, blue = tea_bag[:, :, 0], tea_bag[:, :, 1], tea_bag[:, :, 2]
    blue_mask = (blue > 100) & (red < 100) & (green < 100)

    plt.imshow(blue_mask, cmap='gray')
    plt.title('Segmented image')
    plt.show()


def teabag_analysis_letters_f_2025():
    """Label bright regions of the tea-bag image and filter them by size.

    ``data/tea_bags/TeaBag.png`` is converted to gray scale and thresholded
    at 0.6. The bright connected components are labelled. The function
    prints how many components were found and how many have a perimeter
    above 15. It then shows a binary image keeping only the components with
    area in [100, 1000] and perimeter in [10, 200].
    """
    tea_dir = "data/tea_bags/"
    gray = color.rgb2gray(io.imread(f"{tea_dir}TeaBag.png"))
    labels = measure.label(gray > 0.6)
    print(f"Answer: found {np.max(labels)} objects")

    blobs = measure.regionprops(labels)
    long_outline_count = np.sum(np.array([blob.perimeter for blob in blobs]) > 15)
    print(f"Answer: BLOBS with perimeters above 15: {long_outline_count}")

    # Labels of blobs whose perimeter AND area fall inside the allowed ranges.
    kept = [blob.label for blob in blobs
            if 10 <= blob.perimeter <= 200 and 100 <= blob.area <= 1000]
    kept_mask = np.isin(labels, kept)

    plt.imshow(kept_mask, cmap='gray')
    plt.title('Area and perimeter based segmentation')
    plt.show()



if __name__ == '__main__':
    # Run each exercise in turn. They print their results and may open
    # matplotlib windows that must be closed before the next one starts.
    for exercise in (
        pedestrian_registration_f2025,
        animal_similarity_f_2025,
        teabag_analysis_letters_f_2025,
        change_detection_f2025,
        satellite_image_transformation_and_bilinear_interpolation_f_2025,
        pca_on_ct_image_f_2025,
        gradient_descent_f_2025_v2,
        teabag_analysis_f_2025,
        car_segmentation_f_2025,
    ):
        exercise()
