#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Tue May 3 14:00:23 2022 @author: maltejensen """ ''' Load data ''' import pandas as pd import numpy as np import os import matplotlib.pyplot as plt from sklearn.decomposition import PCA from termcolor import colored from tqdm import tqdm %matplotlib qt file_root = '' df = pd.read_csv(file_root) #%% ''' Onehot encode data before analysis ''' AA = ['A','R','N','D','C','Q','E','G','H','I','L','K','M','F','P','S','T','W','Y','V'] AA_map = {AA[i]: i for i in range(len(AA))} AA_map_rev = {i: AA[i] for i in range(len(AA))} # map of properties prop_map = { 0: ['G','A','V','L','I','P'], # Aliphatic side-chains 1: ['S', 'T'], # Polar neutral side-chains 2: ['N', 'Q'], # Amide side-chains 3: ['C', 'M'], # Sulfur-containing side-chains 4: ['F', 'Y', 'W'], # Aromatic side-chains 5: ['D', 'E'], # Anionic side-chains 6: ['H', 'K', 'R'] # Cationic side-chains } prop_map_rev = { 0: 'Aliphatic', 1: 'Polar', 2: 'Amide', 3: 'Sulfur', 4: 'Aromatic', 5: 'Anionic', 6: 'Cationic' } prop_map = { 0: ['G','A','V','L','I','P'], # Aliphatic side-chains 1: ['S', 'T'], # Polar neutral side-chains 2: ['N', 'Q'], # Amide side-chains 3: ['C', 'M'], # Sulfur-containing side-chains 4: ['F', 'Y', 'W'], # Aromatic side-chains 5: ['D', 'E'], # Anionic side-chains 6: ['H', 'K', 'R'] # Cationic side-chains } prop_map_rev_groups = { 'Aliphatic': ['G','A','V','L','I','P'], # Aliphatic side-chains 'Polar': ['S', 'T'], # Polar neutral side-chains 'Amide': ['N', 'Q'], # Amide side-chains 'Sulfur': ['C', 'M'], # Sulfur-containing side-chains 'Aromatic': ['F', 'Y', 'W'], # Aromatic side-chains 'Anionic': ['D', 'E'], # Anionic side-chains 'Cationic': ['H', 'K', 'R'] # Cationic side-chains } def OneHotEncodeAA(seq, AA_map, num_seq, type_): seq = seq.strip('\'') crop_seq = seq[:num_seq] num_col = len(AA_map) one_hot = np.zeros(num_col*len(crop_seq), dtype=np.int8) if type_ == 'AA': for i in range(len(crop_seq)): one_hot[i*num_col + AA_map[crop_seq[i]]] = 1 elif type_ == 'prop': for i in range(len(crop_seq)): for key in AA_map.keys(): if crop_seq[i] in AA_map[key]: one_hot[i*num_col + key] = 1 break return one_hot # embedding type embedding_type = 'AA' # AA or prop # set number of samples to include and how long the main sequence is num_samples = len(df) #10000 num_seq = 12 if embedding_type == 'AA': seq_map = AA_map seq_map_rev = AA_map_rev type_ = 'AA' elif embedding_type == 'prop': seq_map = prop_map seq_map_rev = prop_map_rev type_ = 'prop' num_groups = len(seq_map) one_hot = [] for index, row in df.iterrows(): one_hot.append(OneHotEncodeAA(row['AA'], seq_map, num_seq=num_seq, type_=type_)) if index == num_samples-1: break one_hot = np.vstack(one_hot) y = df['Relative (%)'][:num_samples] #%% ''' Heatmap of amino acids ''' from mpl_toolkits.axes_grid1 import make_axes_locatable # define baseline num_codons = { 'A': 2, 'R': 3, 'N': 1, 'D': 1, 'C': 1, 'Q': 2, 'E': 1, 'G': 2, 'H': 1, 'I': 1, 'L': 2, 'K': 1, 'M': 1, 'F': 1, 'P': 2, 'S': 3, 'T': 2, 'W': 1, 'Y': 1, 'V': 2 } # Order: 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] baseline_freq = np.array([num_codons[aa] for aa in AA]).reshape(-1,1) baseline_freq = baseline_freq/baseline_freq.sum() AA_heatmap = [] for position in range(num_seq): AA_count = one_hot[:,len(seq_map_rev)*position:len(seq_map_rev)*(position+1)].sum(axis=0)/one_hot.shape[0] AA_heatmap.append(AA_count.reshape(-1,1)) AA_heatmap = np.hstack(AA_heatmap)-baseline_freq fig, axx = plt.subplots() abs_max = np.max(np.abs(AA_heatmap)) img = axx.imshow(AA_heatmap, cmap='bwr', vmin=-abs_max, vmax=abs_max) axx.set_yticks(np.arange(len(seq_map_rev))) axx.set_yticklabels(list(seq_map_rev.values())) axx.set_xticks(np.arange(num_seq)) axx.set_xticklabels(np.arange(1, num_seq+1)) axx.set_xlabel('Position') axx.set_ylabel('Amino acid') axx.set_title('Heat map of amino acid position') divider = make_axes_locatable(axx) cax = divider.append_axes("right", size="5%", pad=0.1) cbar = plt.colorbar(img, cax=cax, label='Fraction (difference from baseline)') plt.show()