Subversion Repositories lagranto.ecmwf

Rev

Go to most recent revision | Blame | Compare with Previous | Last modification | View Log | Download | RSS feed

#!/usr/bin/env python
# -*- coding:utf-8 -*-

# ---------------------------------------------------------------------
# import modules
# ---------------------------------------------------------------------

from __future__ import print_function
import sys
import argparse
from dypy.lagranto import Tra
from dypy.plotting import Mapfigure
import matplotlib.pyplot as plt
import matplotlib
# Select the Qt4Agg backend through matplotlib's runtime configuration.
# NOTE(review): pyplot has already been imported above, so this assignment
# may come too late to influence which backend pyplot uses -- confirm, and
# consider whether an interactive backend is wanted when --save is given.
matplotlib.rcParams['backend'] = "Qt4Agg"
# ---------------------------------------------------------------------
# main function
# ---------------------------------------------------------------------


def _parse_domain(spec, parser):
    """Translate the ``--domain`` option into ``[lon1, lon2, lat1, lat2]``.

    Parameters
    ----------
    spec : str or None
        ``None`` (option not given), one of the named domains ``"nh"`` /
        ``"eu"``, or four comma-separated numbers ``lon1,lon2,lat1,lat2``.
    parser : argparse.ArgumentParser
        Parser used to report malformed values as a usage error.

    Returns
    -------
    list or None
        The four domain bounds, or ``None`` when no domain was requested.
    """
    named_domains = {
        "eu": [-15, 30, 30, 75],
        "nh": [-180, 180, 0, 90],
    }
    if spec is None:
        return None
    if spec in named_domains:
        return named_domains[spec]

    try:
        bounds = [float(coo) for coo in spec.split(',')]
    except ValueError:
        # parser.error() prints the usage message and exits with status 2.
        parser.error("invalid domain '{0}': expected nh, eu or "
                     "lon1,lon2,lat1,lat2".format(spec))
    if len(bounds) != 4:
        parser.error("invalid domain '{0}': expected exactly 4 values "
                     "(lon1,lon2,lat1,lat2), got {1}".format(spec,
                                                             len(bounds)))
    return bounds


def main(argv):
    """Display the trajectories of a file on a map.

    Parameters
    ----------
    argv : list of str
        Command-line arguments without the program name, e.g.
        ``sys.argv[1:]``. Recognised options are ``-d/--domain``,
        ``-v/--variable``, ``-s/--save`` and the positional ``file``.

    The figure is shown interactively unless ``--save`` is given, in which
    case it is written to that file instead. Invalid arguments terminate
    the program through ``argparse`` (``SystemExit``).
    """
    # upper bound used to decide whether the trajectories are subsampled
    max_plotted = 500

    parser = argparse.ArgumentParser(
        description="Display trajectories on world map")
    parser.add_argument("-d", "--domain", type=str,
                        help="Choose domain to plot nh (north hemisphere) "
                             "/ eu (europe) / lon1,lon2,lat1,lat2")
    parser.add_argument("-v", "--variable", type=str, help="Choose a variable")
    parser.add_argument("-s", "--save", type=str, help="Save to file")
    parser.add_argument("file", type=str, help="filename")
    # Parse the arguments handed to main() rather than implicitly re-reading
    # sys.argv, so the function can also be called programmatically.
    args = parser.parse_args(argv)

    # read options

    domain = _parse_domain(args.domain, parser)

    # -----------------------------------------
    # read the trajectories
    # -----------------------------------------

    trajs = Tra(args.file)

    print(trajs.shape)
    if trajs.ntra > max_plotted:
        # Floor division: a slice step must be an integer, and plain "/"
        # yields a float on Python 3.
        step = trajs.ntra // max_plotted
        trajs.set_array(trajs[::step, :])
        printstring = 'Only {} trajectories are plotted (every {})'
        print(printstring.format(trajs.ntra, step))

    if domain is None:
        # no domain requested: use the extent covered by the trajectories
        domain = [trajs['lon'].min(), trajs['lon'].max(),
                  trajs['lat'].min(), trajs['lat'].max()]

    # set default parameters
    if args.variable is None:
        # default to the last field of the file and list the alternatives
        var = trajs.dtype.names[-1]
        printstring = "Fields available for plotting: \n {0}"
        print(printstring.format(trajs.dtype.names))
    else:
        var = args.variable

    # plot trajectories

    m = Mapfigure(domain=domain, resolution='l')

    fig = plt.figure()
    m.ax = fig.add_subplot(111)
    m.drawmap()
    lc = m.plot_traj(trajs, var, rasterized=True)
    m.colorbar(lc, 'right', label=var)

    if args.save is None:
        plt.show()
    else:
        plt.savefig(args.save, bbox_inches='tight')


# Run the command-line interface only when the file is executed as a script.
if __name__ == '__main__':
    # drop the program name; main() expects the options and filename only
    cli_args = sys.argv[1:]
    main(cli_args)