#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May  3 14:00:23 2022

@author: maltejensen
"""
'''
Load data
'''
import pandas as pd
import numpy as np
import os 
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from termcolor import colored
from tqdm import tqdm
%matplotlib qt

    

# Path to the input CSV; fill this in before running the script.
# The code further down reads the columns 'AA' (sequence string) and
# 'Relative (%)' (target value) from this file.
file_root = ''

if not file_root:
    # Fail early with an actionable message rather than letting pandas try
    # to open an empty path.
    raise FileNotFoundError('file_root is empty: set it to the path of the input CSV')

df = pd.read_csv(file_root)

#%%
'''
One-hot encode the sequences before analysis
'''

# The 20 standard amino acids; list position defines the one-hot column.
AA = ['A','R','N','D','C','Q','E','G','H','I','L','K','M','F','P','S','T','W','Y','V']

AA_map = {aa: i for i, aa in enumerate(AA)}       # letter -> column index
AA_map_rev = {i: aa for i, aa in enumerate(AA)}   # column index -> letter


# Side-chain property groups: group index -> member residues.
# Together the groups cover each of the 20 letters in AA exactly once.
prop_map = {
    0: ['G', 'A', 'V', 'L', 'I', 'P'],  # Aliphatic side-chains
    1: ['S', 'T'],                      # Polar neutral side-chains
    2: ['N', 'Q'],                      # Amide side-chains
    3: ['C', 'M'],                      # Sulfur-containing side-chains
    4: ['F', 'Y', 'W'],                 # Aromatic side-chains
    5: ['D', 'E'],                      # Anionic side-chains
    6: ['H', 'K', 'R'],                 # Cationic side-chains
    }

# Group index -> short group name (used as plot labels).
prop_map_rev = {
    0: 'Aliphatic',
    1: 'Polar',
    2: 'Amide',
    3: 'Sulfur',
    4: 'Aromatic',
    5: 'Anionic',
    6: 'Cationic',
    }

# Group name -> member residues. Derived from the two maps above so the
# memberships cannot drift apart; list() gives each entry its own list so
# mutating one dict does not silently change the other.
prop_map_rev_groups = {prop_map_rev[k]: list(v) for k, v in prop_map.items()}

def OneHotEncodeAA(seq, AA_map, num_seq, type_):
    """One-hot encode the first ``num_seq`` residues of a sequence.

    Parameters
    ----------
    seq : str
        Sequence string; surrounding single quotes are stripped first.
    AA_map : dict
        For ``type_ == 'AA'``: maps each residue letter to its column index
        (e.g. ``AA_map``). For ``type_ == 'prop'``: maps each integer group
        index to the residue letters in that group (e.g. ``prop_map``).
    num_seq : int
        Number of leading residues to encode; shorter sequences are encoded
        in full.
    type_ : {'AA', 'prop'}
        Encoding scheme.

    Returns
    -------
    numpy.ndarray
        Flat int8 vector of length ``len(AA_map) * min(len(seq), num_seq)``
        (after quote stripping). The ``i``-th block of ``len(AA_map)``
        entries is the one-hot code of residue ``i``. In ``'prop'`` mode a
        residue that belongs to no group leaves its block all zero.

    Raises
    ------
    ValueError
        If ``type_`` is neither ``'AA'`` nor ``'prop'``.
    KeyError
        In ``'AA'`` mode, if a residue is missing from ``AA_map``.
    """
    if type_ not in ('AA', 'prop'):
        # An unknown scheme would otherwise yield an all-zero vector that is
        # indistinguishable from a real encoding.
        raise ValueError(f"type_ must be 'AA' or 'prop', got {type_!r}")

    crop_seq = seq.strip('\'')[:num_seq]
    num_col = len(AA_map)
    one_hot = np.zeros(num_col * len(crop_seq), dtype=np.int8)

    if type_ == 'AA':
        for pos, residue in enumerate(crop_seq):
            one_hot[pos * num_col + AA_map[residue]] = 1
    else:
        # Invert group -> members once. setdefault keeps the first group that
        # lists a residue, matching a first-match scan over the groups.
        residue_to_group = {}
        for group, members in AA_map.items():
            for residue in members:
                residue_to_group.setdefault(residue, group)
        for pos, residue in enumerate(crop_seq):
            if residue in residue_to_group:
                one_hot[pos * num_col + residue_to_group[residue]] = 1
    return one_hot


# embedding type
embedding_type = 'AA' # AA or prop

# set number of samples to include and how long the main sequence is
num_samples = len(df) #10000
num_seq = 12

# Pick the lookup tables for the chosen embedding: 'AA' uses the 20
# amino-acid columns, 'prop' uses the side-chain property groups.
if embedding_type == 'AA':
    seq_map = AA_map
    seq_map_rev = AA_map_rev
    type_ = 'AA'
elif embedding_type == 'prop':
    seq_map = prop_map
    seq_map_rev = prop_map_rev
    type_ = 'prop'
else:
    # Otherwise seq_map/type_ would be undefined and the script would fail
    # further down with a NameError.
    raise ValueError(f"embedding_type must be 'AA' or 'prop', got {embedding_type!r}")

num_groups = len(seq_map)

# Encode the first num_samples rows by position. .iloc does not depend on
# the DataFrame's index labels, unlike stopping when a row label equals
# num_samples-1, which only works for a default 0..n-1 index.
one_hot = np.vstack([
    OneHotEncodeAA(seq, seq_map, num_seq=num_seq, type_=type_)
    for seq in df['AA'].iloc[:num_samples]
])
y = df['Relative (%)'].iloc[:num_samples]



#%%
'''
Heatmap of per-position amino-acid (or property-group) frequency,
shown as the difference from a codon-count baseline
'''
from mpl_toolkits.axes_grid1 import make_axes_locatable

# define baseline: number of codons assigned to each amino acid
# NOTE(review): these counts sum to 31 and match the NNK/NNS degenerate-codon
# scheme (stop codon excluded) except for Q and L, which NNK/NNS give as
# Q: 1 and L: 3. Confirm against the library design before relying on the
# baseline.
num_codons = {
    'A': 2,
    'R': 3,
    'N': 1,
    'D': 1,
    'C': 1,
    'Q': 2,
    'E': 1,
    'G': 2,
    'H': 1,
    'I': 1,
    'L': 2,
    'K': 1,
    'M': 1,
    'F': 1,
    'P': 2,
    'S': 3,
    'T': 2,
    'W': 1,
    'Y': 1,
    'V': 2
    }

# Expected frequency of each amino acid under the codon baseline.
codon_total = sum(num_codons[aa] for aa in AA)
aa_baseline = {aa: num_codons[aa] / codon_total for aa in AA}

# One baseline value per heatmap row, in seq_map_rev order. In 'AA' mode
# that is each amino acid. In 'prop' mode it is the summed frequency of the
# group's members, so the baseline has as many rows as the heatmap.
# (A fixed 20-row baseline cannot broadcast against the 7 property rows.)
if type_ == 'AA':
    baseline_freq = np.array([aa_baseline[aa] for aa in seq_map_rev.values()])
else:
    baseline_freq = np.array([
        sum(aa_baseline[aa] for aa in seq_map[group])
        for group in seq_map_rev
    ])
baseline_freq = baseline_freq.reshape(-1, 1)

# Column block p of one_hot (len(seq_map_rev) columns wide) encodes position
# p. Averaging over samples gives, per position, the fraction of sequences
# carrying each residue/group there; transpose to rows = residues/groups and
# columns = positions.
num_rows = len(seq_map_rev)
AA_heatmap = (
    one_hot.reshape(one_hot.shape[0], num_seq, num_rows).mean(axis=0).T
    - baseline_freq
)

fig, axx = plt.subplots()
# Symmetric colour limits so 0 (no difference from baseline) sits at the
# centre of the diverging colormap.
abs_max = np.max(np.abs(AA_heatmap))
img = axx.imshow(AA_heatmap, cmap='bwr', vmin=-abs_max, vmax=abs_max)


axx.set_yticks(np.arange(num_rows))
axx.set_yticklabels(list(seq_map_rev.values()))
axx.set_xticks(np.arange(num_seq))
axx.set_xticklabels(np.arange(1, num_seq+1))

axx.set_xlabel('Position')
if type_ == 'AA':
    axx.set_ylabel('Amino acid')
    axx.set_title('Heat map of amino acid position')
else:
    axx.set_ylabel('Side-chain property')
    axx.set_title('Heat map of side-chain property position')

# Place the colorbar in its own axis so it matches the heatmap's height.
divider = make_axes_locatable(axx)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(img, cax=cax, label='Fraction (difference from baseline)')

plt.show()

