import numpy as np
import matplotlib.pyplot as plt
import sys
import os
sys.path.append('../../src/')
from plot_utils import *
from tqdm import tqdm
%matplotlib inline
# Global matplotlib styling applied to every schematic figure below.
plt.rcParams.update({
    'font.sans-serif': "Arial",
    'figure.facecolor': "w",
    'figure.autolayout': True,
    'pdf.fonttype': 3,  # font type used when embedding text in PDF output
})
# Output directories for main-text and supplementary schematics; created if
# they do not exist yet.
main_figdir = '../../plots/schematic_plots/'
supp_figdir = '../../plots/supp_figs/schematic_plots/'
for _figdir in (main_figdir, supp_figdir):
    os.makedirs(_figdir, exist_ok=True)
def _span_arrow(ax, x, y_bottom, y_top, color):
    """Draw a vertical |-| span marker at ``x`` from ``y_bottom`` to ``y_top``."""
    ax.annotate(text='', xy=(x, y_bottom), xytext=(x, y_top),
                arrowprops=dict(arrowstyle='|-|', color=color, shrinkA=0, shrinkB=0))


def _draw_serial_pair(ax, x_offset, t_root, t_label, h_label, t_sample=0.4):
    """Draw one two-lineage tree whose second lineage starts at ``t_sample``.

    One lineage starts at height 0 and the other at ``t_sample``; both meet
    at ``t_root``. Three |-| markers annotate the spans 0..t_sample (blue,
    ``$t_a$``), t_sample..t_root (green, ``t_label``) and 0..t_root
    (orange, ``h_label``). Each label sits at the vertical midpoint of its span.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x_offset : horizontal shift applied to every element, so several trees
        can be placed side by side on the same axes.
    t_root : height at which the two lineages join.
    t_label, h_label : LaTeX strings for the green and orange span labels.
    t_sample : height at which the second lineage starts.
    """
    x = x_offset
    # Lineages: one from the bottom, one starting at t_sample, joining at t_root.
    ax.plot([0.3 + x, 0.4 + x], [0.0, t_root], lw=4, color="black")
    ax.plot([0.45 + x, 0.4 + x], [t_sample, t_root], lw=4, color="black")
    # Span markers for the quantities of interest.
    _span_arrow(ax, 0.51 + x, 0, t_sample, 'blue')
    _span_arrow(ax, 0.51 + x, t_sample, t_root, 'green')
    _span_arrow(ax, 0.22 + x, 0.0, t_root, 'orange')
    ax.text(0.52 + x, t_sample / 2, r'$t_a$', color='blue', fontsize=14)
    ax.text(0.52 + x, (t_sample + t_root) / 2, t_label, color='green', fontsize=14)
    ax.text(0.1 + x, t_root / 2, h_label, color='orange', fontsize=14)


# Plotting a serial sampling at two loci: the same sampling scheme is drawn
# twice, with the lineages joining at a different height in each tree.
f, ax = plt.subplots(1, 1, figsize=(3, 3))
_draw_serial_pair(ax, x_offset=0.0, t_root=0.6, t_label=r'$T_A$', h_label=r'$H_A$')
x = 0.5  # horizontal offset of the second tree
_draw_serial_pair(ax, x_offset=x, t_root=0.7, t_label=r'$T_B$', h_label=r'$H_B$')
# Axis formatting: the schematic has no data axes, so hide ticks and frame.
ax.set_xlim(0.1, 1.25)
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
# Save through the figure handle rather than pyplot's "current figure" state.
f.savefig(os.path.join(supp_figdir, 'two_samples_two_loci.pdf'), dpi=300, bbox_inches='tight')
# Plotting Divergence Model
f, ax = plt.subplots(1, 1, figsize=(3, 3))
# Dashed outlines: a left band from height 0 and a right band from height 0.3
# meet at 0.6, above which a single vertical band continues to 0.8
# (presumably two diverged populations and their ancestral population).
for seg_x, seg_y in [
    ([0.25, 0.35], [0.0, 0.6]),
    ([0.3, 0.4], [0.0, 0.6]),
    ([0.45, 0.4], [0.3, 0.6]),
    ([0.5, 0.45], [0.3, 0.6]),
    ([0.35, 0.35], [0.6, 0.8]),
    ([0.45, 0.45], [0.6, 0.8]),
]:
    ax.plot(seg_x, seg_y, lw=2, color='black', linestyle='--')
# Solid lineages: one starts at height 0, the other at 0.3; they join at 0.75.
ax.plot([0.275, 0.4], [0.0, 0.75], lw=5, color="black")
ax.plot([0.475, 0.4], [0.3, 0.75], lw=5, color="black")
# Vertical |-| span markers for the quantities of interest:
# (x position, bottom, top, color).
for span_x, y_bottom, y_top, span_color in [
    (0.52, 0, 0.3, 'blue'),      # t_a
    (0.52, 0.6, 0.75, 'green'),  # T
    (0.52, 0.3, 0.6, 'orange'),  # t_div
    (0.2, 0, 0.75, 'black'),     # H
]:
    ax.annotate(text='', xy=(span_x, y_bottom), xytext=(span_x, y_top),
                arrowprops=dict(arrowstyle='|-|', color=span_color,
                                shrinkA=0, shrinkB=0))
ax.text(0.53, 0.15, r'$t_a$', color='blue', fontsize=14)
ax.text(0.53, 0.65, r'$T$', color='green', fontsize=14)
ax.text(0.53, 0.45, r'$t_{div}$', color='orange', fontsize=14)
ax.text(0.21, 0.4, r'$H$', color='black', fontsize=14)
# Axis formatting: the schematic has no data axes, so hide ticks and frame.
ax.set_xlim(0.1, 0.7)
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
# Save through the figure handle rather than pyplot's "current figure" state.
f.savefig(os.path.join(supp_figdir, 'schematic_two_samples_divergence.pdf'), dpi=300, bbox_inches='tight')