# Region masks

POP includes a default region mask as part of its grid information. This mask is often not the most useful for analysis.
`pop_tools` provides several alternative region masks, which are demonstrated here.

## Import packages

In [None]:
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import pop_tools

## Load POP grid as `xarray.Dataset`

In [None]:
grid_name = 'POP_gx1v7'

ds = pop_tools.get_grid(grid_name)
ds

## Plot default `REGION_MASK`
The default `REGION_MASK` is a 2-D array with unique integer values for each region. Negative integers denote "marginal seas," which are not directly connected to the ocean.

In [None]:
regions = np.array(np.unique(ds.REGION_MASK))
regions

In [None]:
ds.REGION_MASK.plot.contourf(levels=regions, cmap='tab20');

## More useful region masks

It's often more useful to define a region mask as a 3-D array of zeros and ones, where the first dimension is `region`; this permits overlapping regions and is convenient for computation because the mask can be applied by multiplication, which yields a `region` dimension via broadcasting.
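
As a minimal sketch of this idea, the snippet below uses plain NumPy and a small synthetic integer mask rather than the real POP grid. The array values, region ids, and variable names are made up for illustration. It expands a 2-D integer mask into a stack of zero/one masks and applies the stack by multiplication.

```python
import numpy as np

# Synthetic 2-D integer region mask (values are illustrative, not POP region ids)
region_mask_2d = np.array([[1, 1, 2],
                           [1, 2, 2],
                           [3, 3, 2]])
region_ids = np.unique(region_mask_2d)

# Stack one zero/one mask per region along a leading "region" axis
mask_stack = np.stack([(region_mask_2d == rid).astype(float) for rid in region_ids])

# A synthetic 2-D field on the same grid
field = np.arange(9, dtype=float).reshape(3, 3)

# Broadcasting (3, 3) against (nregion, 3, 3) yields a leading region axis
masked = field * mask_stack

# Unweighted mean of the field within each region
region_means = masked.sum(axis=(1, 2)) / mask_stack.sum(axis=(1, 2))

print(masked.shape)
print(region_means)
```

With `xarray`, the same multiplication aligns on dimension names, so the result carries a labeled `region` dimension instead of a positional axis.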

`pop_tools` supports converting the default `REGION_MASK` to this type of mask through the `region_mask_3d` function.

In [None]:
mask3d = pop_tools.region_mask_3d(grid_name, mask_name='default')
mask3d

This mask includes a coordinate variable with the region names.

In [None]:
mask3d.region

A particular region can be selected by name.

In [None]:
mask3d.sel(region='Southern Ocean').plot();

To visualize all the regions, we can define a helper plotting function,

In [None]:
def visualize_mask(mask3d):
    nregion = len(mask3d.region)

    # mask out land (uses the grid dataset `ds` loaded above)
    mask3d = mask3d.where(ds.KMT > 0)

    # visualize the regions
    ncol = int(np.sqrt(nregion))
    nrow = int(nregion / ncol) + min(1, nregion % ncol)

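    # squeeze=False keeps `ax` a 2-D array even when there is only one panel
    fig, ax = plt.subplots(
        nrow, ncol, figsize=(4 * ncol, 3 * nrow), constrained_layout=True, squeeze=False
    )

    for i, region in enumerate(mask3d.region.values):
        # draw each region on its own axes
        mask3d.sel(region=region).plot(ax=ax.ravel()[i])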

    # delete the unused axes
    for i in range(nregion, ncol * nrow):
        fig.delaxes(ax.ravel()[i])

    fig.suptitle(f'Mask name = {mask3d.mask_name}', fontsize=16)

and apply it to the default mask created above.

In [None]:
visualize_mask(mask3d)

## Alternative region masks

Other useful region masks are pre-defined in the package. `list_region_masks` returns a list of pre-defined masks.

In [None]:
region_masks = pop_tools.list_region_masks(grid_name)
region_masks

We can visualize all of these using the helper function above.

In [None]:
for region_mask in region_masks:
    mask3d = pop_tools.region_mask_3d(grid_name, mask_name=region_mask)
    visualize_mask(mask3d)

To illustrate how the regions cover the global domain, including where they overlap, we can sum over the `region` dimension.

In [None]:
mask3d = pop_tools.region_mask_3d(grid_name, mask_name='lat-range-basin')
mask3d.sum('region').plot();
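
The same sum can also be inspected numerically. The following is a small sketch on a synthetic stack of zero/one masks (the values are invented for illustration and do not come from any `pop_tools` mask): a count of 0 marks a cell that no region covers, and a count above 1 marks overlap.

```python
import numpy as np

# Synthetic stack of three zero/one region masks on a 2x4 grid (illustrative values)
mask_stack = np.array([
    [[1, 1, 0, 0], [1, 1, 0, 0]],  # region A
    [[0, 1, 1, 0], [0, 1, 1, 0]],  # region B, overlapping A in the second column
    [[0, 0, 0, 0], [0, 0, 0, 1]],  # region C
])

# Number of regions covering each grid cell
coverage = mask_stack.sum(axis=0)

n_uncovered = int((coverage == 0).sum())  # cells in no region
n_overlap = int((coverage > 1).sum())     # cells in more than one region

print(coverage)
print(n_uncovered, n_overlap)
```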

## User-defined region masks

Finally, it is also possible to make a region mask on the fly by building a dictionary containing the defining logic. `region_mask_3d` accepts a `region_defs` argument. This is a dictionary of the following form.

```python
region_defs = {region1_name: list_of_criteria_dicts_1, 
               region2_name: list_of_criteria_dicts_2,...}
```

Each `list_of_criteria_dicts` is a list of dictionaries; each dictionary must include the key `'match'`, `'bounds'`, or both. For instance:

```python
list_of_criteria_dicts_1 = [{'match': {'REGION_MASK': [1, 2, 3, 6]}, 
                             'bounds': {'TLAT': [-90., -30.]}}]
```
This will return a mask where the default `REGION_MASK` matches the specified values and `TLAT` falls between the specified bounds. Multiple entries in `list_of_criteria_dicts` are combined with an "or" condition.
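
To make the "and"/"or" logic concrete, here is a rough sketch of how such criteria could be evaluated with NumPy. This is not the `pop_tools` implementation: `evaluate_criteria` is a hypothetical helper, the grid values are synthetic, and treating the bounds as inclusive is an assumption. It also ignores details such as longitude ranges that wrap around the dateline.

```python
import numpy as np


def evaluate_criteria(criteria_list, fields):
    """Hypothetical helper: combine a list of criteria dicts into a boolean mask.

    Within one dict, all 'match' and 'bounds' conditions are combined with "and";
    across the dicts in the list, the results are combined with "or".
    """
    shape = next(iter(fields.values())).shape
    result = np.zeros(shape, dtype=bool)
    for criteria in criteria_list:
        selected = np.ones(shape, dtype=bool)
        for name, values in criteria.get('match', {}).items():
            selected &= np.isin(fields[name], values)
        for name, (lower, upper) in criteria.get('bounds', {}).items():
            # assumption: bounds are treated as inclusive in this sketch
            selected &= (fields[name] >= lower) & (fields[name] <= upper)
        result |= selected
    return result


# Synthetic 1-D "grid" (illustrative values only)
fields = {
    'REGION_MASK': np.array([1, 2, 6, 6, 9]),
    'TLAT': np.array([-60.0, -20.0, -45.0, 10.0, -70.0]),
}
criteria_list = [
    {'match': {'REGION_MASK': [1, 2, 3, 6]}, 'bounds': {'TLAT': [-90.0, -30.0]}}
]

mask = evaluate_criteria(criteria_list, fields)
print(mask)
```

Appending a second dictionary to `criteria_list` can only add points to the mask, since each entry is OR-ed into the result.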

Here's an example region mask generated for the North Atlantic Subpolar and Subtropical Gyres.

In [None]:
region_defs = {
    'NAtl-STG': [
        {'match': {'REGION_MASK': [6]}, 'bounds': {'TLAT': [32.0, 42.0], 'TLONG': [310.0, 350.0]}}
    ],
    'NAtl-SPG': [
        {'match': {'REGION_MASK': [6]}, 'bounds': {'TLAT': [50.0, 60.0], 'TLONG': [310.0, 350.0]}}
    ],
}

mask3d = pop_tools.region_mask_3d(grid_name, region_defs=region_defs, mask_name='N. Atlantic Gyres')
visualize_mask(mask3d)

In [None]:
%load_ext watermark
%watermark -d -iv -m -g -h