
# coding: utf-8

# In[3]:

# Conditional histograms of boreal-area variables, conditioned on the VCFF state
get_ipython().magic(u'matplotlib inline')
import matplotlib
#matplotlib.use('macosx')  # TkAgg Force mpl to use Tk backend
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, shiftgrid
import numpy as np
import seaborn as sns
import datetime
import netCDF4
import sys
import math 
import matplotlib.patches as mpatches
# needs gridspec
from matplotlib import gridspec
from matplotlib.colors import ListedColormap
import palettable 
from pandas import DataFrame 
from collections import Counter
import pandas as pd



from Boreal_Variables import bare_na_west_update as bare
from Boreal_Variables import savanna_na_west_update as savanna
from Boreal_Variables import forest_na_west_update as forest 

# Read the western North America boreal table from a plain-text file.
# NOTE(review): absolute, machine-specific path -- confirm it before running elsewhere.
dataforr_na_west = np.loadtxt(
    "/Users/Beniamino/Documents/OneDrive/MPI-M/iPython_Notebooks/Boreal_Area/boreal_na_west_database_canadianupdate.txt"
)

# Names given to the first 13 columns of the table.
columns = [
    "VCFF", "MAR", "Mean_Spring_SM", "Mean_Tmin", "PZI", "FF", "GDD0",
    "PTD", "ST", "State", "Elev", "Lon", "Lat",
]

# Labelled view of those 13 columns.
dataframe_na_west = DataFrame(dataforr_na_west[:, :13], columns=columns[:13])


# White seaborn style for the figures drawn below.
sns.set_style(style="white")

# Conditional histograms section.
#
# Build the "with fire" subsets: keep only the rows whose column 5 is non-zero.
# Column 5 is "FF" in `columns` for dataforr_na_west; the three per-state arrays
# are presumably laid out the same way -- confirm against Boreal_Variables.
#
# A boolean row mask keeps exactly the rows the original index-collecting loops
# followed by np.delete kept (same order, new array), without a Python-level
# loop over every row.
for_with_fire = forest[forest[:, 5] != 0]
bar_with_fire = bare[bare[:, 5] != 0]
sav_with_fire = savanna[savanna[:, 5] != 0]
data_with_fire = dataforr_na_west[dataforr_na_west[:, 5] != 0]

# Base colour of each class as drawn below: bare -> "treeless",
# savanna -> "open woodland", forest -> "forest".
colbare = '#7570b3'
colsava = '#d95f02'
colfore = '#1b9e77'

# Light-to-base colormaps derived from the same three colours (used for the KDE contours).
cmapbare = sns.light_palette(colbare, n_colors=5, reverse=False, as_cmap=True)
cmapsava = sns.light_palette(colsava, n_colors=5, reverse=False, as_cmap=True)
cmapfore = sns.light_palette(colfore, n_colors=5, reverse=False, as_cmap=True)

# Axis labels (with units where given), indexed by column number.
columns_2 = [
    "VCFF [%]",
    "MAR - [" + r'mm yr$^{-1}$' + "]",
    "Mean_Spring_SM - [" + r'mm' + "]",
    "Mean_Tmin - [" + r'$^\circ$C' + "]",
    "PZI",
    "FF - [" + r'fires yr$^{-1}$' + "]",
    "GDD0 - [" + r'$^\circ$C yr$^{-1}$' + "]",
    "PTD",
    "ST",
    "State",
    "Elev",
    "Lon",
    "Lat",
]

# Short column names, indexed by column number (used in figure titles).
columns = [
    "VCFF", "MAR", "Mean_Spring_SM", "Mean_Tmin", "PZI", "FF", "GDD0",
    "PTD", "ST", "State", "Elev", "Lon", "Lat",
]

# create plot with conditional histograms and joint scatterplot
def scatter_rug_hist_plot(bare, savanna, forest, dataforr_ea_east, index, index_2,
                          region='Western North America', save_path=None):
    """Joint scatter plot of two columns for three classes, with stacked marginal histograms.

    The figure is a 2x2 grid: the joint scatter plot with rug marks (bottom left),
    the stacked histogram of the y variable (bottom right), the stacked histogram
    of the x variable (top left) and a legend-only panel (top right).

    Parameters
    ----------
    bare, savanna, forest : 2-D arrays indexed as ``[row, column]``
        Per-class data, drawn as "treeless", "open woodland" and "forest".
    dataforr_ea_east : 2-D array indexed as ``[row, column]``
        Table whose columns ``index`` and ``index_2`` set the bin range and the
        axis limits (parameter name kept for backward compatibility).
    index, index_2 : int
        Column numbers of the x and y variables. They also index the
        module-level ``columns`` (title) and ``columns_2`` (axis labels).
    region : str, optional
        Text placed in front of the variable names in the title.
    save_path : str or None, optional
        If given, the figure is saved there (dpi=300) before it is shown.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.

    Notes
    -----
    NOTE(review): ``normed=`` (matplotlib ``hist``) and ``vertical=`` (seaborn
    ``rugplot``) are legacy keywords kept for the library versions this
    notebook was written against -- confirm the installed versions before
    upgrading either library.
    """
    groups = (bare, savanna, forest)
    colours = [colbare, colsava, colfore]
    class_labels = ["treeless", "open woodland", "forest"]

    # 50 equal-width bins per variable over the range of the reference table.
    # The x bins start 0.01 above the column minimum (kept from the original;
    # presumably to leave exact-minimum values out of the top histogram -- confirm).
    x_all = dataforr_ea_east[:, index]
    y_all = dataforr_ea_east[:, index_2]
    binBoundaries = np.linspace(np.min(x_all) + 0.01, np.max(x_all), 51)
    binBoundaries2 = np.linspace(np.min(y_all), np.max(y_all), 51)

    # Axis limits: the binned range widened by 25 % of its width on each side.
    xmin, xmax = np.min(binBoundaries), np.max(binBoundaries)
    ymin, ymax = np.min(binBoundaries2), np.max(binBoundaries2)
    x_pad = 0.25 * (xmax - xmin)
    y_pad = 0.25 * (ymax - ymin)
    xmin, xmax = xmin - x_pad, xmax + x_pad
    ymin, ymax = ymin - y_pad, ymax + y_pad

    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[1, 4])

    # Joint panel: scatter plus rug marks on both axes, one colour per class.
    ax = fig.add_subplot(gs[1, 0], frameon=False, xticks=[], yticks=[],
                         xlim=(xmin, xmax), ylim=(ymin, ymax))
    for subset, colour in zip(groups, colours):
        ax.scatter(subset[:, index], subset[:, index_2], alpha=.4, color=colour)
        sns.rugplot(subset[:, index], color=colour, ax=ax, height=0.02)
        sns.rugplot(subset[:, index_2], color=colour, vertical=True, ax=ax, height=0.02)

    ax.set_xlabel(columns_2[index])
    ax.set_ylabel(columns_2[index_2])
    sns.despine(fig=fig)

    # Y-marginal (right): stacked, normalised, horizontal histogram.
    axr = fig.add_subplot(gs[1, 1], xticks=[], frameon=False, ylim=(ymin, ymax))
    axr.hist([subset[:, index_2] for subset in groups], label=class_labels,
             stacked=True, bins=binBoundaries2, alpha=0.8, normed=True,
             orientation='horizontal', color=colours)

    # X-marginal (top): stacked, normalised histogram; carries the title.
    axt = fig.add_subplot(gs[0, 0], frameon=False, yticks=[], xlim=(xmin, xmax))
    axt.set_title(region + ' ' + str(columns[index]) + ' vs ' + str(columns[index_2]))
    axt.hist([subset[:, index] for subset in groups], label=class_labels,
             stacked=True, bins=binBoundaries, alpha=0.8, normed=True,
             color=colours)

    # Legend panel: a tiny histogram supplies the legend handles; its x-limits
    # lie entirely below the data, so the bars themselves never appear.
    axl = fig.add_subplot(gs[0, 1], frameon=False, xticks=[], yticks=[],
                          xlim=(xmin - 100, xmin - 50))
    axl.hist([subset[0:2, index] for subset in groups], label=class_labels,
             alpha=0.8, color=colours)
    axl.legend(loc="lower left", ncol=1, labels=['Treeless', 'Open woodland', 'Forest'])

    # Save before showing: the figure may no longer be available afterwards.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, transparent=False)
    plt.show()

def kde_rug_hist_plot(bare, savanna, forest, dataforr_ea_east, index, index_2, fsize = 16, figsize = (12,8),
                      region='Western North America', save_path=None):
    """Joint KDE contour plot of two columns for three classes, with stacked marginal histograms.

    The figure is a 2x2 grid: the bivariate KDE contours with rug marks
    (bottom left), the stacked count histogram of the y variable (bottom right),
    the stacked count histogram of the x variable (top left) and a legend-only
    panel (top right).

    Parameters
    ----------
    bare, savanna, forest : 2-D arrays indexed as ``[row, column]``
        Per-class data, drawn as "treeless", "open woodland" and "forest".
    dataforr_ea_east : 2-D array indexed as ``[row, column]``
        Table whose columns ``index`` and ``index_2`` set the bin range and the
        axis limits (parameter name kept for backward compatibility).
    index, index_2 : int
        Column numbers of the x and y variables. They also index the
        module-level ``columns`` (title) and ``columns_2`` (axis labels).
    fsize : int, optional
        Font size for axis labels, tick labels, title and legend.
    figsize : tuple, optional
        Figure size passed to ``plt.figure``.
    region : str, optional
        Text placed in front of the variable names in the title.
    save_path : str or None, optional
        If given, the figure is saved there (dpi=300, transparent background)
        before it is shown.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.

    Notes
    -----
    NOTE(review): the positional two-array form of ``sns.kdeplot`` and the
    ``vertical=`` keyword of ``sns.rugplot`` are legacy seaborn usage kept for
    the version this notebook was written against -- confirm the installed
    version before upgrading.
    """
    groups = (bare, savanna, forest)
    colours = [colbare, colsava, colfore]
    contour_cmaps = (cmapbare, cmapsava, cmapfore)
    class_labels = ["treeless", "open woodland", "forest"]

    def _count_ticks(counts):
        # Five integer ticks for a count axis: 0, then 1/4, 1/2 and 3/4 of
        # (smallest + largest count), then the largest count itself.
        largest = np.max(counts)
        span = np.min(counts) + largest
        return [0, int(span / 4), int(span / 2), int(span / 4 + span / 2), int(largest)]

    # 50 equal-width bins per variable over the range of the reference table.
    # The x bins start 0.01 above the column minimum (kept from the original;
    # presumably to leave exact-minimum values out of the top histogram -- confirm).
    x_all = dataforr_ea_east[:, index]
    y_all = dataforr_ea_east[:, index_2]
    binBoundaries = np.linspace(np.min(x_all) + 0.01, np.max(x_all), 51)
    binBoundaries2 = np.linspace(np.min(y_all), np.max(y_all), 51)

    # Axis limits: the binned range widened by 25 % of its width on each side.
    xmin, xmax = np.min(binBoundaries), np.max(binBoundaries)
    ymin, ymax = np.min(binBoundaries2), np.max(binBoundaries2)
    x_pad = 0.25 * (xmax - xmin)
    y_pad = 0.25 * (ymax - ymin)
    xmin, xmax = xmin - x_pad, xmax + x_pad
    ymin, ymax = ymin - y_pad, ymax + y_pad

    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[1, 4])

    # Joint panel: KDE contours plus rug marks on both axes, one colour per class.
    ax = fig.add_subplot(gs[1, 0], frameon=False, xticks=[], yticks=[],
                         xlim=(xmin, xmax), ylim=(ymin, ymax))
    for subset, cmap, colour in zip(groups, contour_cmaps, colours):
        sns.kdeplot(subset[:, index], subset[:, index_2], ax=ax, cmap=cmap, linewidths=2)
        sns.rugplot(subset[:, index], color=colour, ax=ax, height=0.02, linewidth=0.2)
        sns.rugplot(subset[:, index_2], color=colour, vertical=True, ax=ax,
                    height=0.02, linewidth=0.2)

    ax.set_xlabel(columns_2[index], fontsize=fsize)
    ax.set_ylabel(columns_2[index_2], fontsize=fsize)
    sns.despine(fig=fig)

    # Y-marginal (right): stacked horizontal count histogram with five count ticks.
    axr = fig.add_subplot(gs[1, 1], xticks=[], frameon=False, ylim=(ymin, ymax))
    hist_Y = axr.hist([subset[:, index_2] for subset in groups], label=class_labels,
                      stacked=True, bins=binBoundaries2, alpha=0.8,
                      orientation='horizontal', color=colours)
    axr.set_xticks(_count_ticks(hist_Y[0]))
    axr.tick_params(axis='both', labelsize=fsize)

    # X-marginal (top): stacked count histogram with five count ticks; carries the title.
    axt = fig.add_subplot(gs[0, 0], frameon=False, yticks=[], xlim=(xmin, xmax))
    axt.set_title(region + ' ' + str(columns[index]) + ' vs ' + str(columns[index_2]),
                  fontsize=fsize)
    hist_X = axt.hist([subset[:, index] for subset in groups], label=class_labels,
                      stacked=True, bins=binBoundaries, alpha=0.8, color=colours)
    axt.set_yticks(_count_ticks(hist_X[0]))
    axt.tick_params(axis='both', labelsize=fsize)

    # Legend panel: a tiny histogram supplies the legend handles; its x-limits
    # lie entirely below the data, so the bars themselves never appear.
    axl = fig.add_subplot(gs[0, 1], frameon=False, xticks=[], yticks=[],
                          xlim=(xmin - 100, xmin - 50))
    axl.hist([subset[0:2, index] for subset in groups], label=class_labels,
             alpha=0.8, color=colours)
    axl.legend(loc="lower left", ncol=1, labels=['Treeless', 'Open woodland', 'Forest'],
               fontsize=fsize)

    # Save before showing: the figure may no longer be available afterwards.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, transparent=True)
    plt.show()


# In[4]:

# Mean_Spring_SM (column 2) against GDD0 (column 6) for the full data set.
kde_rug_hist_plot(
    bare, savanna, forest, dataforr_na_west,
    index=2, index_2=6,
    fsize=20, figsize=(12, 8),
)


# In[3]:

# Mean_Tmin (column 3) against FF (column 5), using only the rows with non-zero FF.
kde_rug_hist_plot(
    bar_with_fire, sav_with_fire, for_with_fire, data_with_fire,
    index=3, index_2=5,
)


# In[ ]:



