#!/bin/python3
import os
import sys
import re
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import dates
from matplotlib import rc
from datetime import datetime
from datetime import date, timedelta
from numpy import copy, nanmean, array, pi
import pylab 
import math
import argparse
import pandas as pd
import argparse
import pdb
import datetime as dt
import matplotlib.dates as mdates

def read_nutrients_csv(stationID, base_dir="/users/rsg/observatory/nutrients"):
    """Load the nutrient CSV for one station into a DataFrame.

    Parameters
    ----------
    stationID : str
        Station identifier (e.g. "L4" or "E1"). It is lower-cased to build the
        filename ``<stationid>_nutrients.csv``.
    base_dir : str, optional
        Directory holding the per-station CSV files. Defaults to the
        observatory path this script has always read from.

    Returns
    -------
    pandas.DataFrame
        The parsed table. NITRITE, AMMONIA and NITRATE+NIT are forced to float64.
    """
    path = os.path.join(base_dir, stationID.lower() + "_nutrients.csv")
    return pd.read_csv(
        path,
        # Drop file line 2 (the row straight after the header); presumably a
        # units row -- TODO confirm against the source files.
        skiprows=[1],
        encoding="ISO-8859-1",
        dtype={'NITRITE': np.float64, 'AMMONIA': np.float64, 'NITRATE+NIT': np.float64},
        # Tokens read as missing (NaN); '<' presumably flags below-detection
        # values in the source data -- confirm.
        na_values=['<', 'N', 'N,', 'n'],
    )

def moving_average(a, n=3):
    """Return the trailing simple moving average of ``a`` over windows of ``n`` samples.

    Uses the cumulative-sum trick. ``a`` is flattened (``np.cumsum`` is called
    without ``axis``). The result is a float array of ``len(a) - n + 1``
    elements, or empty when ``n > len(a)``.

    Raises
    ------
    ValueError
        If ``n < 1``. Without this check, ``n == 0`` fails with an opaque
        broadcasting error and a negative ``n`` silently returns meaningless values.
    """
    if n < 1:
        raise ValueError("moving_average window n must be >= 1, got %r" % (n,))
    ret = np.cumsum(a, dtype=float)
    # Convert running totals into window totals: S[i] - S[i-n] == sum(a[i-n+1 .. i]).
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

def convert_as_needed(ts):
    """Parse a ``DD/MM/YYYY`` date string into a :class:`datetime.datetime`.

    Returns the empty string ``''`` when ``ts`` cannot be parsed. This covers a
    string in another format and non-string input such as ``None`` or the float
    NaN that pandas uses for empty cells. The ``''`` sentinel is kept for
    compatibility with existing callers.
    """
    try:
        return datetime.strptime(ts, '%d/%m/%Y')
    except (TypeError, ValueError):
        # TypeError: ts is not a str (None / NaN); ValueError: wrong format.
        # Catch only these so KeyboardInterrupt/SystemExit still propagate.
        return ''

def generate_alternate_months(start_date, end_date):
    """List every other month-start timestamp in the range [start_date, end_date].

    The first month start on or after ``start_date`` is kept, then the one two
    months later, and so on. Returns a plain list of ``pandas.Timestamp``.
    """
    month_starts = pd.date_range(start=start_date, end=end_date, freq='MS')
    # A stride-2 slice keeps positions 0, 2, 4, ... of the month starts.
    return list(month_starts[::2])


# Sampling depths (metres) drawn on the "last 12 months" panel; 0 m is also
# the series shown on the full-record panel.
_DEPTHS = (0, 10, 20, 25, 30, 40, 50, 60)

# Directory that receives the <nutrient>_<station>.png figures.
_OUTPUT_DIR = '/users/rsg/observatory/nutrients/'


def _select_depth(sdate, depth, level):
    """Return ``(datetimes, rows)`` for the samples taken at ``level`` metres.

    ``rows`` are positional indices into the DataFrame columns. Rows whose date
    ``convert_as_needed`` cannot parse are dropped, so x and y stay aligned and
    matplotlib never receives a non-datetime x value.
    """
    when_list = []
    rows = []
    for pos in np.flatnonzero((depth == level).to_numpy()):
        when = convert_as_needed(sdate.iloc[pos])
        if isinstance(when, datetime):
            when_list.append(when)
            rows.append(pos)
    return when_list, np.asarray(rows, dtype=int)


def _plot_nutrient(stationID, name, values, levels, ylim, color,
                   alternate_dates, last_data, xinches, yinches):
    """Save a two-panel PNG for one nutrient: full 0 m record, and last 12 months by depth.

    ``levels`` maps each depth in ``_DEPTHS`` to the ``(datetimes, rows)`` pair
    from ``_select_depth``. ``color`` is the left-panel line colour; ``None``
    keeps matplotlib's default colour cycle.
    """
    label = name.lower()
    fig = plt.figure()
    fig.set_size_inches(xinches, yinches)

    # Left panel: the whole surface (0 m) record.
    ax1 = fig.add_subplot(121)
    ax1.set_title(stationID + " " + label)
    ax1.set_xlabel("Year")
    ax1.xaxis.set_major_locator(mdates.YearLocator(5))
    ax1.xaxis.set_minor_locator(mdates.YearLocator())
    ax1.set_ylabel(name + r" ($\mu$M)")
    surface_dates, surface_rows = levels[0]
    surface_values = values.iloc[surface_rows].to_numpy()
    line_kw = {} if color is None else {'c': color}
    ax1.plot(surface_dates, surface_values, **line_kw)
    ax1.scatter(surface_dates, surface_values, s=5, c='black')
    ax1.set_ylim(*ylim)
    ax1.grid(linestyle="--", color='gray')
    ax1.grid(linestyle="--", color='gray', which='minor')

    # Right panel: every depth over the 365 days ending at last_data.
    ax2 = fig.add_subplot(122)
    ax2.set_title(stationID + " " + label + " (Last 12 months)")
    ax2.set_xlabel("Year (MM/YY)")
    ax2.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[d.month for d in alternate_dates]))
    ax2.xaxis.set_minor_locator(mdates.MonthLocator())
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%m/%y'))
    for level in _DEPTHS:
        level_dates, level_rows = levels[level]
        level_values = values.iloc[level_rows].to_numpy()
        ax2.plot(level_dates, level_values, label=str(level) + ' m')
        ax2.scatter(level_dates, level_values, s=5, c='black')
    ax2.set_ylim(*ylim)
    ax2.set_xlim(last_data - dt.timedelta(days=365), last_data)
    ax2.grid(linestyle="--", color='gray')
    ax2.grid(linestyle="--", color='gray', which='minor')
    ax2.legend()

    # Lay out once, after every artist exists, then release the figure.
    fig.tight_layout()
    fig.savefig(_OUTPUT_DIR + label + '_' + stationID + '.png', dpi=100)
    plt.close(fig)


def main():
    """Command-line entry point: plot nutrient time series for one station.

    Reads the station CSV via ``read_nutrients_csv``. Writes nitrate, nitrite,
    phosphate and silicate PNGs into ``_OUTPUT_DIR``. Exits with status 1 when
    no station is given or when there is no dated surface (0 m) sample.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("-s", "--station", type=str, default='', help="--station <L4 or E1>")
    args = parser.parse_args()

    if not args.station:
        print("Set station ID: --station L4 or E1")
        sys.exit(1)
    stationID = args.station

    df = read_nutrients_csv(stationID)

    sdate = df['Date']
    depth = df['Depth']
    nitrite = df['NITRITE']
    # The L4 file uses an abbreviated header for the combined column. Compare
    # case-insensitively: read_nutrients_csv lower-cases the ID, so "l4" also
    # loads the L4 file and must select the same column.
    if stationID.upper() == "L4":
        nitrate_nitrite = df['NITRATE+NIT']
    else:
        nitrate_nitrite = df['NITRATE+NITRITE']
    silicate = df['SILICATE']
    phosphate = df['PHOSPHATE']

    nitrate = nitrate_nitrite - nitrite

    levels = {level: _select_depth(sdate, depth, level) for level in _DEPTHS}

    surface_dates = levels[0][0]
    if not surface_dates:
        print("No dated surface (0 m) samples for station " + stationID)
        sys.exit(1)
    # Use the latest date rather than the final row, so the 12-month window
    # does not depend on the CSV being sorted chronologically.
    last_data = max(surface_dates)
    alternate_dates = generate_alternate_months(last_data - dt.timedelta(days=365), last_data)

    xinches = 10.0
    yinches = 3.5
    for name, values, ylim, color in (
        ("Nitrate", nitrate, (0, 20), None),
        ("Nitrite", nitrite, (0, 3), 'red'),
        ("Phosphate", phosphate, (0, 2), 'green'),
        ("Silicate", silicate, (0, 20), 'yellow'),
    ):
        _plot_nutrient(stationID, name, values, levels, ylim, color,
                       alternate_dates, last_data, xinches, yinches)

    sys.exit(0)


if __name__ == '__main__':
    main()
