
gdptools-Pangeo Method Comparison for CONUS404 Spatial Aggregation

In this notebook, we compare two methods for spatially aggregating gridded data to polygons. The first uses gdptools. The second uses conservative regional methods implemented natively with xarray and geopandas (see this Pangeo Discourse for details).

The goal of this comparison is to see how the results of the two methods compare, to help judge the efficacy of one versus the other.

%xmode minimal
import os
# Needed when boto3 >= 1.36.0 or the rechunking process will fail
# This needs to be set before the boto3 library gets loaded
# See: https://github.com/aws/aws-cli/issues/9214#issuecomment-2606619168
os.environ['AWS_REQUEST_CHECKSUM_CALCULATION'] = 'when_required'
os.environ['AWS_RESPONSE_CHECKSUM_VALIDATION'] = 'when_required'
import time
import xarray as xr
import rioxarray  # registers the .rio accessor used to reproject the plotting cutout
import geopandas as gp
import pandas as pd
import numpy as np
import sparse

import hvplot.pandas
import hvplot.xarray
import dask
import cf_xarray

from pynhd import NLDI, WaterData
from pygeohydro import watershed
import cartopy.crs as ccrs
from shapely.geometry import Polygon

import pyproj
from gdptools import WeightGen, AggGen, UserCatData
import pystac
from packaging.version import Version
import zarr

Open dataset from STAC Catalog

First, let’s begin by loading the CONUS404 daily data.

def get_children(catalog, collection_id=None):
    """
    This function retrieves a specified collection from a STAC catalog/collection and prints key metadata 
    for exploring/accessing the datasets contained within it.
    If there is no collection ID provided, the collections in the top level of the catalog will be printed.
    If a collection ID is provided, it will retrieve the collection with that ID from the input catalog/collection.
    If the collection ID points to a dataset, it will print the assets available for the dataset.
    If the collection ID points to another collection, it will list the child collections in the IDed collection.

    Args:
        catalog (pystac.Catalog | pystac.Collection): The STAC catalog/collection object.
        collection_id (str): The ID of the collection or dataset to retrieve from catalog.
    
    Returns:
        collection (pystac.Catalog | pystac.Collection): The collection object corresponding to the provided ID
                                                         or the top-level catalog if no ID is provided.
    """
    dataset = False
    if collection_id:
        collection = catalog.get_child(collection_id)
        if collection.assets:
            dataset = True
            print(f"{collection_id} is a dataset. Please review the assets below and select one to open.")

        else:
            print(f"{collection_id} is a collection. Please review the child items and select one to open in the next cell.")
    else:
        collection = catalog
    if dataset:
        # List the assets
        for asset in collection.assets:
            print(f"Asset ID: {asset}")
            print(f"    Title: {collection.assets[asset].title}")
            print(f"    Description: {collection.assets[asset].description}")
    else:
        collections = list(collection.get_collections())
        print(f"Number of collections: {len(collections)}")
        print("Collections IDs:")
        for child_collection in collections:
            child_id = child_collection.id
            cite_as = "Not available"
            for link in child_collection.links:
                if link.rel == "cite-as":
                    cite_as = link.target
            print(f"- {child_id}, Source: {cite_as}")
    return collection
# url for the WMA STAC Catalog
catalog_url = "https://api.water.usgs.gov/gdp/pygeoapi/stac/stac-collection/"

# use pystac to read the catalog
catalog = pystac.Catalog.from_file(catalog_url)

# list the collections in the catalog
catalog = get_children(catalog)
# select a collection from the catalog, replace the collection ID with the one you want to use:
collection = get_children(catalog, collection_id="conus404")
# select a dataset from the collection, replace the dataset ID with the one you want to use:
collection = get_children(collection, collection_id="conus404_daily")

As we can see, the conus404_daily dataset is available from two different locations: (1) assets ending in -hovenweep, which are stored on the USGS Hovenweep HPC, and (2) assets ending in -osn, which are stored on the USGS Open Storage Network (OSN). Because the OSN is free to access from any environment, we use it for this example, but the location can easily be changed depending on your needs.

# replace with the asset ID you want to use:
selected_asset_id = "zarr-s3-osn"

# read the asset metadata
asset = collection.assets[selected_asset_id]
# uncomment the lines below to read in your AWS credentials if you want to access data from a requester-pays bucket (-cloud)
# os.environ['AWS_PROFILE'] = 'default'
# %run ../../../environment_set_up/Help_AWS_Credentials.ipynb

Finally, read in the daily CONUS404 dataset and select the accumulated grid-scale precipitation. We select only the precipitation rather than all variables to keep this example simple, but aggregating other variables would follow the same methodology.

if Version(zarr.__version__) < Version("3.0.0"):
    conus404 = xr.open_dataset(
        asset.href,
        storage_options=asset.extra_fields['xarray:storage_options'],
        **asset.extra_fields['xarray:open_kwargs']
    )
else:
    conus404 = xr.open_dataset(
        asset.href,
        storage_options=asset.extra_fields['xarray:storage_options'],
        **asset.extra_fields['xarray:open_kwargs'],
        zarr_format=2
    )

# Include the crs as we will need it later
conus404 = conus404[['PREC_ACC_NC', 'crs']]
conus404

Parallelize with Dask (optional)

Some of the steps we will take are aware of parallel clustered compute environments using Dask. We can start a cluster now so that future steps take advantage of this ability. This step is optional, but it speeds up data loading significantly, especially when accessing data from the cloud.

We have documentation on how to start a Dask Cluster in different computing environments here. Uncomment the cluster start up that works for your compute environment.

%run ../../../environment_set_up/Start_Dask_Cluster_Nebari.ipynb
## If this notebook is not being run on Nebari, replace the above 
## path name with a helper appropriate to your compute environment.  Examples:
# %run ../../../environment_set_up/Start_Dask_Cluster_Denali.ipynb
# %run ../../../environment_set_up/Start_Dask_Cluster_Tallgrass.ipynb
# %run ../../../environment_set_up/Start_Dask_Cluster_Desktop.ipynb

Load the Feature Polygons

Now that we have read in the CONUS404 data, we need some polygons to aggregate the data to. For this example, we will use the HUC12 basins within the Delaware River Basin. To get these HUC12 polygons, we first use pygeohydro.watershed.WBD to retrieve the HUC4 boundary that contains the Delaware River Basin (0204), and then use pynhd.WaterData to query the HUC12 polygons that intersect that boundary. Finally, we keep only the HUC12 basins whose IDs start with 020401 or 020402, which are the ones that make up the Delaware River Basin.

%%time
wbd = watershed.WBD("huc4")
delaware_basin = wbd.byids(field="huc4", fids="0204")
huc12_basins = WaterData('wbd12').bygeom(delaware_basin.iloc[0].geometry)
huc12_basins = huc12_basins[huc12_basins['huc12'].str.startswith(('020401', '020402'))]

Let’s plot the HUC12 basins to see how they look.

huc12_basins.hvplot(
    c='huc12', title="Delaware River HUC12 basins",
    coastline='50m', geo=True,
    aspect='equal', legend=False, frame_width=300
)

An important thing to note is that all geodataframes should have a coordinate reference system (CRS). Let’s check to make sure our geodataframe has a CRS.

huc12_basins.crs

Limit CONUS404 Spatial Range

With the HUC12 basins read in, we only need the CONUS404 data that spans these polygons, as they are the regions we will be aggregating to. So, let's limit the CONUS404 spatial range to that of the basins. This will save on memory and computation. Note that this is mainly useful when the footprint of the regions is much smaller than the footprint of the gridded model.

To limit the spatial range, we first need to convert the CRS of the basins to that of CONUS404 and then extract the bounding box of the basins.

huc12_basins_conus404_crs = huc12_basins.to_crs(conus404.crs.crs_wkt)
bbox = huc12_basins_conus404_crs.total_bounds
bbox

Then we select the CONUS404 data within the bounding box. When we do this, we extend the bounds by 5% of their range on each side to ensure all of our basins are within the spatially limited data. We do this because reprojecting between CRSs can cause slight distortions that leave polygons near the bounds partially outside the data, and because sel selects grid cells by their center coordinates, so edge cells that only partially overlap the basins could otherwise be dropped.

bbox_x_range = bbox[2] - bbox[0]
bbox_y_range = bbox[3] - bbox[1]
x_range = slice(bbox[0] - bbox_x_range * 0.05,
                bbox[2] + bbox_x_range * 0.05)
y_range = slice(bbox[1] - bbox_y_range * 0.05,
                bbox[3] + bbox_y_range * 0.05)

conus404 = conus404.sel(x=x_range, y=y_range)
conus404

To make sure this worked as intended, let’s plot the full basin over the extracted CONUS404 data.

# Select a single timestamp for simple plotting
timestamp = '2000-05-02'
cutout = conus404.sel(time=timestamp).drop_vars(['lat', 'lon'])
# We need to write the CRS to the CONUS404 dataset and
# reproject for clean plotting with hvplot
cutout = cutout.rio.write_crs(conus404.crs.crs_wkt).rio.reproject('EPSG:4326')

cutout_plt = cutout.hvplot(
    coastline='50m', geo=True,
    aspect='equal', cmap='viridis', frame_width=300
)
huc12_plt = huc12_basins.hvplot(
    geo=True, alpha=0.3, c='r'
)
cutout_plt * huc12_plt

Looks good!

Aggregate CONUS404 to Feature Polygons

Now that we have our gridded data and polygons, it is time to aggregate them, first with gdptools and then with what we call the native method, which uses geopandas and xarray.

NOTE: gdptools handles a number of pre-processing steps for the user:

  • Subsets the gridded data to a buffered bounding box of the target polygons.

  • Checks the longitude bounds and, if they are in the interval 0 to 360, rotates them into -180 to 180.

  • Checks the order of the latitude bounds (i.e., top-to-bottom or bottom-to-top) and automatically accounts for this in the subsetting operation above.
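To make these steps concrete, here is a minimal xarray sketch of what rotating longitudes and order-aware bounding-box subsetting amount to. It is an illustration under our own assumptions (a dataset with 1-D lat/lon coordinates, and the hypothetical helper names rotate_longitude and subset_bbox), not gdptools' actual implementation.

```python
import numpy as np
import xarray as xr


def rotate_longitude(ds, lon="lon"):
    """Map longitudes from [0, 360) into [-180, 180) and re-sort (hypothetical helper)."""
    ds = ds.assign_coords({lon: ((ds[lon] + 180) % 360) - 180})
    return ds.sortby(lon)


def subset_bbox(ds, minx, miny, maxx, maxy, buffer=0.0, lon="lon", lat="lat"):
    """Select a buffered bounding box for ascending or descending latitude (hypothetical helper)."""
    lat_vals = ds[lat].values
    if lat_vals[0] < lat_vals[-1]:
        # bottom-to-top ordering: slice from south to north
        lat_slice = slice(miny - buffer, maxy + buffer)
    else:
        # top-to-bottom ordering: slice from north to south
        lat_slice = slice(maxy + buffer, miny - buffer)
    return ds.sel({lon: slice(minx - buffer, maxx + buffer), lat: lat_slice})


# Toy global grid with 0-360 longitudes and top-to-bottom latitudes
ds = xr.Dataset(
    {"v": (("lat", "lon"), np.zeros((11, 36)))},
    coords={"lat": np.arange(50, -60, -10), "lon": np.arange(0, 360, 10)},
)
ds = rotate_longitude(ds)
sub = subset_bbox(ds, minx=-20, miny=-10, maxx=20, maxy=10)
print(sub["v"].shape)  # (3, 5)
```

Without the latitude-order check, slicing a top-to-bottom grid with slice(miny, maxy) would return an empty selection, which is why this kind of check is worth automating.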

gdptools Aggregation

Let’s start by using the gdptools aggregation method, where we use three data classes provided by gdptools, in the order discussed below.

  1. UserCatData stores the data required to perform the aggregation.

  2. WeightGen is a class used to generate the areal weights used to calculate the areal-weighted interpolation. The weights generated between a source and target dataset can be reused as long as the source and target are consistent. For example, if a new time period becomes available, or a different set of variables is needed, the same weights can be used.

  3. AggGen is a class that is used to calculate the aggregation.

The first step to aggregating with gdptools is to convert the input data to a UserCatData object. Note that the var parameter could also be a list of variables, in which case, when the user_data object is used in AggGen, the calculate_agg() method will perform the aggregation over all variables in the list.

user_data = UserCatData(
    ds=conus404,
    proj_ds=conus404.crs.crs_wkt,
    x_coord='x',
    y_coord='y',
    t_coord='time',
    var='PREC_ACC_NC',
    f_feature=huc12_basins,
    proj_feature=huc12_basins.crs,
    id_feature='huc12',
    period=[pd.Timestamp(conus404.time.values.min()),
            pd.Timestamp(conus404.time.values.max())],
)

The UserCatData can then be used to generate weights for each polygon. An important thing to note is that the weights must be generated in an equal-area projection (i.e., an equal-area CRS), because the weights are ratios of overlapping areas and other projections distort those areas.
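As a quick illustration of why the CRS matters, the sketch below measures the same 1 degree by 1 degree box near 40°N in two equal-area projections (EPSG:6931, used below, and EPSG:5070) and in Web Mercator (EPSG:3857), which is not equal area. The box and its location are arbitrary choices for illustration.

```python
import geopandas as gp
from shapely.geometry import box

# A hypothetical 1-degree by 1-degree box in lon/lat near the Delaware River Basin
cell = gp.GeoSeries([box(-75.0, 40.0, -74.0, 41.0)], crs="EPSG:4326")

# Area of the same box in each projection, converted from m^2 to km^2
areas_km2 = {
    epsg: float(cell.to_crs(epsg).area.iloc[0]) / 1e6
    for epsg in ["EPSG:6931", "EPSG:5070", "EPSG:3857"]
}
for epsg, area in areas_km2.items():
    print(f"{epsg}: {area:,.0f} km^2")
```

The two equal-area projections should agree closely, while Web Mercator inflates the area by roughly 1/cos²(latitude), about 1.7 times at this latitude. Fractional weights computed in such a projection would overweight cells farther from the equator.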

crs_area = "EPSG:6931" # Good for northern hemisphere
# crs_area = "EPSG:5070" # Good for CONUS

# time the weight generation for later comparison
t0 = time.time()

weight_gen = WeightGen(
    user_data=user_data,
    # use serial here vs dask as the dask overhead would cause
    # a slow down since our example is relatively small scale
    method="serial",
    weight_gen_crs=crs_area,
)

df_gdptools_weights = weight_gen.calculate_weights()

gdptools_weights_time = time.time() - t0

df_gdptools_weights

With the weights, we can now perform the aggregation.

Note that the return values of calculate_agg() are:

  1. ngdf, the target GeoDataFrame, sorted by id_feature and filtered to only those IDs that have weights. In other words, if the source does not completely overlap the target, some target IDs will not have values. If you wish to plot the resulting interpolated data, note that the returned GeoDataFrame's ID order is the same as that of gdptools_aggregation.

  2. gdptools_aggregation, an xarray.Dataset containing the interpolated output with dimensions of time and id_feature. In the case below, the agg_writer parameter is set to 'none'; it can instead be set to 'netcdf', 'csv', or 'parquet' to archive the results to a file.

t0 = time.time()

agg_gen = AggGen(
    user_data=user_data,
    # Use masked to ignore NaNs
    # Note that we use mean vs sum, as sum seems to ignore
    # the weights, even though with weights that sum to one the
    # two should be equivalent (i.e., weighted sum = weighted mean)
    stat_method="masked_mean",
    agg_engine="dask",
    weights=df_gdptools_weights,
    agg_writer='none',
)
_, gdptools_aggregation = agg_gen.calculate_agg()

gdptools_agg_time = time.time() - t0

gdptools_aggregation
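The comment in the cell above about masked_mean versus a weighted sum can be illustrated with a small NumPy sketch. We assume here that a masked mean drops NaN cells and renormalizes the remaining weights; this is our reading of the method name, not a copy of gdptools' code, and both helper names are hypothetical.

```python
import numpy as np


def masked_weighted_mean(values, weights):
    """Weighted mean over non-NaN values, renormalizing the remaining weights."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        return np.nan
    return float(np.sum(values[valid] * weights[valid]) / np.sum(weights[valid]))


def weighted_sum_skip_nan(values, weights):
    """Weighted sum that treats NaN values as zero (no renormalization)."""
    values = np.asarray(values, dtype=float)
    return float(np.nansum(values * np.asarray(weights, dtype=float)))


w = [0.5, 0.25, 0.25]  # fractional-area weights for one basin, summing to one
print(masked_weighted_mean([1.0, 2.0, 3.0], w), weighted_sum_skip_nan([1.0, 2.0, 3.0], w))  # 1.75 1.75
print(masked_weighted_mean([1.0, np.nan, 3.0], w), weighted_sum_skip_nan([1.0, np.nan, 3.0], w))
```

With no missing cells and weights that sum to one, the two agree. Once a cell is NaN, the masked mean still returns an average of the valid cells (about 1.667 here), while the NaN-skipping sum drops that cell's share of the weight and returns 1.25.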

Let’s make a nice plot of the aggregated HUC12 basins to make sure the aggregation worked as expected.

# xarray holds the huc12s in sorted order
gdptools_huc12_basins = huc12_basins.copy().sort_values('huc12')
gdptools_huc12_basins['aggregation'] = gdptools_aggregation.sel(time=timestamp)['PREC_ACC_NC']

gdptools_plt = gdptools_huc12_basins.hvplot(
    c='aggregation', title="Accumulated Precipitation over HUC12 basins",
    coastline='50m', geo=True, cmap='viridis',
    aspect='equal', legend=False, frame_width=300
)

cutout_plt * gdptools_plt + cutout_plt

Native Method

For the native method, we first need to extract the grid information from our CONUS404 dataset. We then use it to create polygon boxes that we overlay with the basin polygons to generate weights. Finally, as with gdptools, we use the weights to aggregate via a weighted sum.

To give a fair computational time comparison with gdptools, we will group all steps to generate the weights into one timed cell.

Create Weights

To generate the weights, we (1) extract grid information (including the x and y grid and their bounds), (2) use these bounds to create polygons of the grid, (3) assign the polygons to a GeoDataFrame with the CONUS404 dataset's CRS, (4) overlay the grid polygons and basin polygons, and (5) use the overlay to get fractional area weights.
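To see steps (4) and (5) in isolation, here is a toy version with a hypothetical 2 by 2 grid of unit cells and one hypothetical basin with ID "A"; the geometries are made up for illustration. The basin covers half of one cell and all of another, so those cells should receive weights of 1/3 and 2/3, and the cells that only touch the basin along an edge should be dropped by keep_geom_type=True.

```python
import geopandas as gp
from shapely.geometry import box

# Hypothetical 2 x 2 grid of unit cells, labeled by cell-center coordinates
cells = gp.GeoDataFrame(
    {"y": [0.5, 0.5, 1.5, 1.5], "x": [0.5, 1.5, 0.5, 1.5]},
    geometry=[box(0, 0, 1, 1), box(1, 0, 2, 1), box(0, 1, 1, 2), box(1, 1, 2, 2)],
)
# One hypothetical basin covering x in [0.5, 2] and y in [0, 1]
basin = gp.GeoDataFrame({"huc12": ["A"]}, geometry=[box(0.5, 0, 2, 1)])

# (4) intersect cells with the basin; line-only intersections are discarded
overlay = cells.overlay(basin, keep_geom_type=True)

# (5) fraction of each basin's area contributed by each intersected cell
weights = overlay.geometry.area.groupby(overlay["huc12"]).transform(lambda a: a / a.sum())
result = {
    (float(y), float(x)): float(w)
    for y, x, w in zip(overlay["y"], overlay["x"], weights)
}
print(result)
```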

%%time
t0 = time.time()
# (1) extract grid info
grid = conus404[['x', 'y']].drop_vars(['lat', 'lon']).reset_coords()
grid = grid.cf.add_bounds(['x', 'y'])


# (2) create polygons of the grid
# use a simple helper function. This way we can use xarray to parallelize.
def bounds_to_poly(x_bounds, y_bounds):
    return Polygon([
        (x_bounds[0], y_bounds[0]),
        (x_bounds[0], y_bounds[1]),
        (x_bounds[1], y_bounds[1]),
        (x_bounds[1], y_bounds[0])
    ])

# Stack the grid cells into a single stack (i.e., x-y pairs)
points = grid.stack(point=('y', 'x'))

# Apply the function to create polygons from bounds
boxes = xr.apply_ufunc(
    bounds_to_poly,
    points.x_bounds,
    points.y_bounds,
    input_core_dims=[("bounds",),  ("bounds",)],
    output_dtypes=[np.dtype('O')],
    vectorize=True
)


# (3) assign polygons to geodataframe with CRS
grid_polygons = gp.GeoDataFrame(
    data={"geometry": boxes.values, "y": boxes['y'], "x": boxes['x']},
    index=boxes.indexes["point"],
    crs=conus404.crs.crs_wkt
)


# (4) overlay the grid polygons with basin polygons
# transform both to an area preserving projection
huc12_basins_area = huc12_basins.to_crs(crs_area)
grid_polygons = grid_polygons.to_crs(crs_area)

# overlay the polygons.
overlay = grid_polygons.overlay(huc12_basins_area, keep_geom_type=True)


# (5) calculate the area fraction for each region
grid_cell_fraction = overlay.geometry.area.groupby(overlay['huc12']).transform(lambda x: x / x.sum())

# turn this into a series
multi_index = overlay.set_index(['y', 'x', 'huc12']).index
df_native_weights = pd.Series(grid_cell_fraction.values, index=multi_index)

da_native_weights_stacked = xr.DataArray(df_native_weights)

# unstack to a sparse array.
native_weights = da_native_weights_stacked.unstack(sparse=True, fill_value=0.)

native_weights_time = time.time() - t0

native_weights

Now that we have our weights, we can clearly see that this is a sparse matrix, with a density of ~0.0025 (i.e., only 0.25% of values are non-zero). So, maintaining it as a sparse matrix is the right move for memory conservation, especially as this process scales up.

Also, this process is area conserving. We can verify this for each basin’s area with a simple area calculation.

# calculate areas of HUC12s from overlay and original polygons
overlay_area = overlay.geometry.area.groupby(overlay['huc12']).sum()
huc12_area = huc12_basins_area.geometry.area.groupby(huc12_basins_area['huc12']).sum()
# find the max fractional difference
(np.abs(overlay_area - huc12_area) / huc12_area).max()

Nice! This means the differences can be attributed to machine precision.

We can also verify that the cell fractions all sum up to one.

grid_cell_fraction.groupby(overlay['huc12']).sum().unique()

Perform Aggregation

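To aggregate the data, we can use xarray.Dataset.weighted to do our weighted calculations. This is simple, as it accepts a sparse array as weights and computes the aggregation directly. Because each basin's weights sum to one, the weighted sum over x and y is the area-weighted mean precipitation for that basin.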
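A toy example shows the call pattern: weights carrying an extra huc12 dimension broadcast against the data, and summing over x and y yields one value per basin. The values and basin IDs here are hypothetical, and dense NumPy weights stand in for the sparse array used in this notebook.

```python
import numpy as np
import xarray as xr

# Hypothetical 2 x 2 field of precipitation values
values = xr.DataArray(np.array([[1.0, 2.0], [3.0, 4.0]]), dims=("y", "x"))

# Hypothetical fractional-area weights with dims (y, x, huc12);
# each basin's weights sum to one
w = np.zeros((2, 2, 2))
w[0, 0, 0] = w[0, 1, 0] = 0.5         # basin "A" split evenly over the top row
w[1, 0, 1], w[1, 1, 1] = 0.25, 0.75  # basin "B" weighted toward the bottom-right cell
weights = xr.DataArray(w, dims=("y", "x", "huc12"), coords={"huc12": ["A", "B"]})

# Reduce over the grid dimensions; the huc12 dimension is kept
weighted_sum = values.weighted(weights).sum(dim=["x", "y"])
weighted_mean = values.weighted(weights).mean(dim=["x", "y"])
print(weighted_sum.values, weighted_mean.values)
```

Both reductions give 1.5 for basin A and 3.75 for basin B, since dividing by a sum of weights equal to one changes nothing.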

%%time
t0 = time.time()

native_aggregation = conus404.drop_vars('crs').weighted(native_weights).sum(dim=['x', 'y']).compute()

native_agg_time = time.time() - t0

native_aggregation

Like the gdptools aggregation results, let’s make some plots to make sure this worked as expected.

# xarray holds the huc12s in sorted order
native_huc12_basins = huc12_basins.copy().sort_values('huc12')
native_huc12_basins['aggregation'] = native_aggregation.sel(time=timestamp)['PREC_ACC_NC'].data.todense()
native_plt = native_huc12_basins.hvplot(
    c='aggregation', title="Accumulated Precipitation over HUC12 basins",
    coastline='50m', geo=True, cmap='viridis',
    aspect='equal', legend=False, frame_width=300
)

cutout_plt * native_plt + cutout_plt

Compare the Results

With both aggregation methods complete, we are now ready to compare the results. We can do this both for the final output and the intermediate weights.

Weight Comparison

To do the weight comparison, we first need to standardize the weight outputs. This is relatively simple, as we just need to convert the gdptools weights DataFrame into an xarray.DataArray. We can do this just like we did for the native method, after first assigning x and y coordinate values to the gdptools DataFrame using its i and j indices.

# Due to the buffer region, gdptools weights i index values
# are off by 3 and j are off by 1. This was found from a manual inspection
df_gdptools_weights['y'] = conus404['y'].isel(y=df_gdptools_weights['i']+3).data
df_gdptools_weights['x'] = conus404['x'].isel(x=df_gdptools_weights['j']+1).data
gdptool_weights = xr.DataArray(
    df_gdptools_weights.set_index(['y', 'x', 'huc12'])['wght']
).unstack(sparse=True, fill_value=0)

Now, the maximum fractional difference is a simple check of how they compare.
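The metric used here and below is max |a - b| / |b|, with the native result b as the reference. Where both arrays are zero (for example, grid cells outside every basin), that ratio is 0/0, so a standalone version has to decide how to treat those entries. The sketch below, with the hypothetical helper name max_fractional_difference, skips them and reports infinity if one method assigns weight where the reference has none.

```python
import numpy as np


def max_fractional_difference(a, b):
    """Largest |a - b| / |b|, skipping entries where both a and b are zero."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    ref_zero = b == 0
    if np.any(a[ref_zero] != 0):
        # one method assigns weight where the reference has none
        return np.inf
    if ref_zero.all():
        return 0.0
    keep = ~ref_zero
    return float(np.max(np.abs(a[keep] - b[keep]) / np.abs(b[keep])))


print(max_fractional_difference([0.5, 0.5, 0.0], [0.5, 0.5, 0.0]))  # 0.0
print(max_fractional_difference([1.0, 2.02], [1.0, 2.0]))
print(max_fractional_difference([1.0, 1e-3], [1.0, 0.0]))  # inf
```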

(np.abs(gdptool_weights - native_weights) / native_weights).max()

Look at that. They are identical (up to machine precision). So, the only other thing to compare would be the time required for the computation.

print(f'gdptools weights computation time: {gdptools_weights_time:0.3f} seconds')
print(f'native weights computation time: {native_weights_time:0.3f} seconds')
print(f'computation time difference: {(gdptools_weights_time - native_weights_time):0.3f} seconds')
print(f'computation time ratio: {(gdptools_weights_time / native_weights_time):0.3f}')

So, from this comparison, we can see that both methods give the same weights, but the method using xarray and geopandas is slightly faster (though likely not significantly). However, this does not test how well either of the two methods scales.
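Because each timing above comes from a single run, small differences may be noise from caching, network reads, or Dask worker warm-up. One way to firm this up is to repeat a measurement and report the best and median times. The helper name time_repeated below is our own, and it uses time.perf_counter, which is intended for measuring short intervals.

```python
import statistics
import time


def time_repeated(func, repeats=5):
    """Call func() `repeats` times and return (best, median) wall-clock seconds."""
    durations = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        func()
        durations.append(time.perf_counter() - t0)
    return min(durations), statistics.median(durations)


# Stand-in workload; in this notebook one could pass weight_gen.calculate_weights instead
best, median = time_repeated(lambda: sum(range(100_000)), repeats=3)
print(f"best: {best:0.4f} s, median: {median:0.4f} s")
```

The native weight steps would first need to be wrapped in a function to be timed the same way.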

Aggregation Comparison

To do the aggregated data comparison, there is no need for any data formatting, as both gdptools and the native method have matching xarray.Dataset formats. So, let’s start with the simple max fractional difference to compare.

(np.abs(gdptools_aggregation - native_aggregation) / native_aggregation).max()

Well, as expected, they are nearly identical, since they had nearly identical weights.

Let’s plot the fractional difference for a timestamp just to see how they compare.

# xarray holds the huc12s in sorted order
diff_huc12_basins = huc12_basins.copy().sort_values('huc12')
diff_huc12_basins['aggregation'] = (np.abs(gdptools_aggregation - native_aggregation) / native_aggregation).sel(time=timestamp)['PREC_ACC_NC']

diff_huc12_basins.hvplot(
    c='aggregation', title="Difference in Precipitation over HUC12 basins",
    coastline='50m', geo=True, cmap='viridis',
    aspect='equal', legend=False, frame_width=300
)

Finally, let’s compare the computational times.

print(f'gdptools aggregation computation time: {gdptools_agg_time:0.3f} seconds')
print(f'native aggregation computation time: {native_agg_time:0.3f} seconds')
print(f'computation time difference: {(gdptools_agg_time - native_agg_time):0.3f} seconds')
print(f'computation time ratio: {(gdptools_agg_time / native_agg_time):0.3f}')

It looks like this step takes about the same time for both. So, let’s compare the total computation time (weights and aggregation).

gdptools_total_time = gdptools_weights_time + gdptools_agg_time
native_total_time = native_weights_time + native_agg_time
print(f'gdptools total computation time: {gdptools_total_time:0.3f} seconds')
print(f'native total computation time: {native_total_time:0.3f} seconds')
print(f'total computation time difference: {(gdptools_total_time - native_total_time):0.3f} seconds')
print(f'total computation time ratio: {(gdptools_total_time / native_total_time):0.3f}')

Alright, since both the weight and aggregation times are about equal, the overall performance of the two methods is about equal as well. Therefore, for this example, either method appears to be a solid choice. The only other thing to test would be how well each method scales to larger sets of feature polygons and larger grids. However, we will leave that comparison for another notebook.

Shut down the Dask Client

If a Dask cluster was started, we should shut down the Dask client.

client.close()