#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------
# This is illustrative code developed for tutorial purposes; it is not
# intended for scientific use and is not guaranteed to be accurate or correct.
"""
Usage:
    consolidate-observations

Environment Variables:
    DOMAIN: The area for which to generate forecasts, in the format
        (lng1, lat1, lng2, lat2).
    RESOLUTION: The length/width of each grid cell in degrees.

"""

from glob import glob
import math
import os
import sys

import util


def plot_wind_data(wind_x, wind_y, x_range, y_range, x_coords, y_coords,
                   z_coords):
    """Generate a 2D vector plot of the wind data if matplotlib is installed.

    The interpolated fields are drawn as arrows at every grid point and the
    data points they were derived from are overlaid in red.  The figure is
    saved as ``wind.png`` (relative to the current working directory).

    Args:
        wind_x: Callable ``wind_x(x, y)`` giving the x component of the
            wind at a point.
        wind_y: Callable ``wind_y(x, y)`` giving the y component of the
            wind at a point.
        x_range: The x coordinates of the grid.
        y_range: The y coordinates of the grid.
        x_coords: The x coordinates of the data points.
        y_coords: The y coordinates of the data points.
        z_coords: One ``(x_component, y_component)`` pair per data point.

    If matplotlib cannot be imported a message is written to stderr and the
    function returns without plotting.
    """
    try:
        import matplotlib
        # The backend is selected before pyplot is imported; "Agg" renders
        # to files and does not need a display.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        print('matplotlib not installed: plotting disabled.', file=sys.stderr)
        return

    # Evaluate the grid once and reuse it for both coordinate lists.
    # NOTE(review): assumes util.permutations yields (x, y) pairs - confirm.
    grid_points = list(util.permutations(x_range, y_range))
    x_points = [point[0] for point in grid_points]
    y_points = [point[1] for point in grid_points]

    figure = plt.figure()
    try:
        # Interpolated field, sampled at every grid point.
        plt.quiver(x_points,
                   y_points,
                   [wind_x(x, y) for x, y in zip(x_points, y_points)],
                   [wind_y(x, y) for x, y in zip(x_points, y_points)])
        # Original data points, highlighted for comparison.
        plt.quiver(x_coords,
                   y_coords,
                   [vector[0] for vector in z_coords],
                   [vector[1] for vector in z_coords],
                   color='red')
        plt.savefig('wind.png')
    finally:
        # Release the figure so that it does not linger in pyplot's
        # global state once the image has been written (or plotting failed).
        plt.close(figure)


def get_wind_fields():
    """Read in wind.csv files, interpolate and return the wind fields.

    Observation files are located by globbing for
    ``get_observations*/wind.csv`` beneath the directory
    ``$CYLC_WORKFLOW_WORK_DIR/$CYLC_TASK_CYCLE_POINT``.

    Each file is read as a single comma-separated record of numbers which
    is interpreted as ``lat, lng, direction, speed`` with the direction in
    degrees.
    NOTE(review): the column meanings are inferred from how the values are
    used here - confirm against the task which writes these files.

    Returns:
        tuple: ``(wind_x, wind_y, (x_coords, y_coords, z_coords))`` where
        ``wind_x`` and ``wind_y`` are ``util.SurfaceFitter`` objects fitted
        to the x and y components of the wind, ``x_coords`` / ``y_coords``
        are the longitudes / latitudes of the observations (ordered by
        longitude) and ``z_coords`` holds one ``(x_component, y_component)``
        pair per observation.

    Raises:
        KeyError: If either environment variable is not set.
        ValueError: If a file contains a value which is not a number.
    """
    # Filename glob for observation files.
    pattern = os.path.join(
        os.environ['CYLC_WORKFLOW_WORK_DIR'],
        os.environ['CYLC_TASK_CYCLE_POINT'],
        'get_observations*',
        'wind.csv')

    # Extract observation data.
    # glob() makes no promise about ordering, so sort the filenames to keep
    # the order of observations which share a longitude reproducible.
    data = []
    for filename in sorted(glob(pattern)):
        with open(filename, 'r') as datafile:
            # Split on the bare comma: float() ignores surrounding
            # whitespace so "a, b", "a,b" and a trailing newline all parse.
            fields = datafile.read().split(',')
        data.append([float(field) for field in fields])
    data.sort(key=lambda datum: datum[1])

    # Convert data into (lng, lat, (wind_x, wind_y)) components.
    x_coords = [datum[1] for datum in data]
    y_coords = [datum[0] for datum in data]
    z_coords = []
    for datum in data:
        angle = math.radians(datum[2])
        z_coords.append((datum[3] * math.sin(angle),
                         datum[3] * math.cos(angle)))

    # Interpolate this data to generate wind "fields".
    wind_x = util.SurfaceFitter(
        x_coords, y_coords, [vector[0] for vector in z_coords])
    wind_y = util.SurfaceFitter(
        x_coords, y_coords, [vector[1] for vector in z_coords])

    return wind_x, wind_y, (x_coords, y_coords, z_coords)


def main():
    """Interpolate the wind observations over the domain and save them.

    The grid is defined by the ``DOMAIN`` and ``RESOLUTION`` environment
    variables.  The interpolated fields are plotted for reference and then
    written out as ``wind_x.csv`` and ``wind_y.csv``.
    """
    bounds = util.parse_domain(os.environ['DOMAIN'])
    cell_size = float(os.environ['RESOLUTION'])

    # Fields for the longitudinal (x) and latitudinal (y) components of
    # wind, along with the observations they were interpolated from.
    field_x, field_y, observations = get_wind_fields()

    # Grid coordinates spanning the domain.
    lngs = list(util.frange(bounds['lng1'], bounds['lng2'], cell_size))
    lats = list(util.frange(bounds['lat1'], bounds['lat2'], cell_size))

    # Reference plot of the fields.
    plot_wind_data(field_x, field_y, lngs, lats, *observations)

    # One csv file per wind component.
    outputs = ((field_x, 'wind_x.csv'), (field_y, 'wind_y.csv'))
    for field, filename in outputs:
        util.field_to_csv(field, lngs, lats, filename)


# Run the consolidation only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
