Subversion Repositories lagranto.ecmwf

Rev

Go to most recent revision | Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
35 michaesp 1
#!/usr/bin/env python
2
# -*- coding:utf-8 -*-
3
 
4
# ---------------------------------------------------------------------
5
# import modules
6
# ---------------------------------------------------------------------
7
 
8
from __future__ import print_function
9
import sys
10
import argparse
11
from dypy.lagranto import Tra
12
from dypy.plotting import Mapfigure
13
import matplotlib.pyplot as plt
14
import matplotlib
15
# Request the Qt4Agg interactive backend for plt.show().
# NOTE(review): this assignment happens after matplotlib.pyplot has already
# been imported above.  The documented way to choose a backend is
# matplotlib.use() before importing pyplot.  "Qt4Agg" is also not shipped
# by recent matplotlib releases, where this line may raise — confirm
# against the matplotlib version this script is deployed with.
matplotlib.rcParams['backend'] = "Qt4Agg"
16
# ---------------------------------------------------------------------
17
# main function
18
# ---------------------------------------------------------------------
19
 
20
 
21
def _parse_domain(spec):
    """Convert a ``--domain`` argument into ``[lon1, lon2, lat1, lat2]``.

    Accepts the presets ``nh`` (northern hemisphere) and ``eu`` (Europe),
    or four comma-separated numbers ``lon1,lon2,lat1,lat2``.

    Raises ``argparse.ArgumentTypeError`` on malformed input.  Argparse then
    reports a usage error and exits with status 2, instead of the script
    dying with a traceback.
    """
    presets = {
        "nh": [-180, 180, 0, 90],
        "eu": [-15, 30, 30, 75],
    }
    if spec in presets:
        # Return a copy so callers cannot mutate the preset table.
        return list(presets[spec])
    try:
        bounds = [float(coo) for coo in spec.split(',')]
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid domain {0!r}: coordinates must be numbers".format(spec))
    if len(bounds) != 4:
        raise argparse.ArgumentTypeError(
            "invalid domain {0!r}: expected nh, eu or "
            "lon1,lon2,lat1,lat2".format(spec))
    return bounds


def main(argv=None):
    """Plot the trajectories stored in a LAGRANTO file on a map.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments, excluding the program name.  ``None`` lets
        argparse fall back to ``sys.argv[1:]``.

    The figure is shown interactively, or written to the ``--save`` path
    when one is given.
    """
    parser = argparse.ArgumentParser(
        description="Display trajectories on world map")
    parser.add_argument("-d", "--domain", type=_parse_domain,
                        help="Choose domain to plot: nh (north hemisphere)"
                             " / eu (europe) / lon1,lon2,lat1,lat2")
    parser.add_argument("-v", "--variable", type=str, help="Choose a variable")
    parser.add_argument("-s", "--save", type=str, help="Save to file")
    parser.add_argument("file", type=str, help="filename")
    # Honour the argv passed by the caller (previously ignored).
    args = parser.parse_args(argv)

    # -----------------------------------------
    # read the trajectories
    # -----------------------------------------

    trajs = Tra(args.file)

    print(trajs.shape)

    # Thin out large files so the plot stays readable.  The slice step must
    # be an int: true division would yield a float step under Python 3.
    max_traj = 500
    if trajs.ntra > max_traj:
        step = trajs.ntra // max_traj
        trajs.set_array(trajs[::step, :])
        printstring = 'Only {} trajectories are plotted (every {})'
        print(printstring.format(trajs.ntra, step))

    domain = args.domain
    if domain is None:
        # Default to the bounding box of the (possibly thinned) trajectories.
        domain = [trajs['lon'].min(), trajs['lon'].max(),
                  trajs['lat'].min(), trajs['lat'].max()]

    # Without an explicit variable, colour by the last field in the file
    # and list the available fields for the user.
    if args.variable is None:
        var = trajs.dtype.names[-1]
        printstring = "Fields available for plotting: \n {0}"
        print(printstring.format(trajs.dtype.names))
    else:
        var = args.variable

    # plot trajectories

    m = Mapfigure(domain=domain, resolution='l')

    fig = plt.figure()
    m.ax = fig.add_subplot(111)
    m.drawmap()
    lc = m.plot_traj(trajs, var, rasterized=True)
    m.colorbar(lc, 'right', label=var)

    if args.save is None:
        plt.show()
    else:
        plt.savefig(args.save, bbox_inches='tight')
91
 
92
 
93
if __name__ == '__main__':
    # Script entry point: forward the command-line arguments (minus the
    # program name) to main().
    main(argv=sys.argv[1:])