ESA WorldCover Classification Inference#

In this notebook, we will use the model trained using Sentinel-2 L2A data and ESA WorldCover from 2020 to perform inference on Sentinel-2 images from 2022. The model will classify the land cover types in the images based on the classes specified during the model training. We will use a combination of what we have learned so far during this training: STAC catalogs, StackSTAC, xarray, and PyTorch.

Import libraries#

import geopandas as gpd
import numpy as np
import pandas as pd
import pystac_client
import stackstac
import torch
from rasterio.enums import Resampling
from shapely import Point
import torch
import torch.nn.functional as F
from scipy.ndimage import zoom
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import rasterio

Setup the model architecture (the same as training)#

class UNet(nn.Module):
    def __init__(self, in_channels=4, out_classes=11, dropout=0.1, init_weights=True):
        super(UNet, self).__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(256, 512)

        # Decoder
        self.dec3 = conv_block(512 + 256, 256)
        self.dec2 = conv_block(256 + 128, 128)
        self.dec1 = conv_block(128 + 64, 64)
        self.dec0 = conv_block(64, 32)

        # Final classifier
        self.final = nn.Conv2d(32, out_classes, kernel_size=1)

        if init_weights:
            self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        e1 = self.enc1(x)  # (B, 64, H, W)
        e2 = self.enc2(self.pool(e1))  # (B, 128, H/2, W/2)
        e3 = self.enc3(self.pool(e2))  # (B, 256, H/4, W/4)
        b = self.bottleneck(self.pool(e3))  # (B, 512, H/8, W/8)

        d3 = self._upsample_concat(b, e3)  # (B, 512+256, H/4, W/4)
        d3 = self.dec3(d3)

        d2 = self._upsample_concat(d3, e2)  # (B, 256+128, H/2, W/2)
        d2 = self.dec2(d2)

        d1 = self._upsample_concat(d2, e1)  # (B, 128+64, H, W)
        d1 = self.dec1(d1)

        d0 = self.dec0(d1)  # (B, 32, H, W)

        return self.final(d0)  # logits: (B, out_classes, H, W)

    def _upsample_concat(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)

Search STAC catalog for Sentinel-2 L2A images#

# Area and period of interest: a point in western Greece, January 2022.
lat, lon = 38.25, 21.25
start = "2022-01-01"
end = "2022-01-31"
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Search the catalogue for mostly cloud-free scenes covering the point.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-3, lat - 1e-3, lon + 1e-3, lat + 1e-3),
    max_items=100,
    query={
        "eo:cloud_cover": {"lt": 20},
        "s2:nodata_pixel_percentage": {"lt": 10},
    },
)

# item_collection() replaces the deprecated get_all_items().
all_items = search.item_collection()

# Keep only the first item seen for each granule id, preserving search order.
items = []
seen_granules = set()
for item in all_items:
    granule_id = item.properties["s2:granule_id"]
    if granule_id not in seen_granules:
        items.append(item)
        seen_granules.add(granule_id)


print(f"Found {len(items)} Sentinel-2-L2A items")
/Users/syam/virtualenvs/myvenv/lib/python3.13/site-packages/pystac_client/item_search.py:896: FutureWarning: get_all_items() is deprecated, use item_collection() instead.
  warnings.warn(
Found 4 Sentinel-2-L2A items
items[0]

Prepare bounding box and create the dataset using StackSTAC#

# Project the point of interest from WGS84 into the CRS of the first item.
proj = items[0].properties["proj:code"]
poi = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
).to_crs(proj)
x_center, y_center = poi.iloc[0].geometry.coords[0]

# Square window of `size` pixels at `gsd` metres/pixel, centred on the point.
size = 2047
gsd = 20
half_extent = (size * gsd) // 2
bounds = (
    x_center - half_extent,
    y_center - half_extent,
    x_center + half_extent,
    y_center + half_extent,
)
assets = ["blue", "green", "red", "nir", "swir16", "scl"]

# Lazily stack the items into one (time, band, y, x) dask-backed DataArray.
ds = stackstac.stack(
    items,
    assets=assets,
    bounds=bounds,  # pyright: ignore
    resolution=gsd,
    epsg=int(proj.split(":")[-1]),
    dtype="float64",  # pyright: ignore
    rescale=False,
    snap_bounds=True,
    resampling=Resampling.nearest,
    chunksize=(1, 1, 512, 512),
)

# SCL classes treated as invalid (clouds, shadows, cirrus, etc.).
cloud_values = [0, 1, 2, 3, 8, 9, 10]
valid_mask = ~ds.sel(band="scl").isin(cloud_values)
reflectance = ds.sel(band=["blue", "green", "red", "nir", "swir16"])
bands_masked = reflectance.where(valid_mask)

# Monthly median composite over valid observations; remaining gaps become 0.
median_ds = bands_masked.groupby("time.month").median("time", skipna=True).fillna(0)
median_ds
<xarray.DataArray 'stackstac-432af8210ba1d087a72b9ed0b155cf6d' (month: 1,
                                                                band: 5,
                                                                y: 2048, x: 2048)> Size: 168MB
dask.array<where, shape=(1, 5, 2048, 2048), dtype=float64, chunksize=(1, 2, 1024, 1024), chunktype=numpy.ndarray>
Coordinates: (12/20)
  * band                                     (band) <U6 120B 'blue' ... 'swir16'
  * x                                        (x) float64 16kB 5.014e+05 ... 5...
  * y                                        (y) float64 16kB 4.254e+06 ... 4...
    s2:datatake_type                         <U8 32B 'INS-NOBS'
    earthsearch:boa_offset_applied           bool 1B False
    s2:processing_baseline                   <U5 20B '03.01'
    ...                                       ...
    s2:degraded_msi_data_percentage          int64 8B 0
    processing:software                      object 8B {'sentinel2-to-stac': ...
    mgrs:latitude_band                       <U1 4B 'S'
    s2:product_type                          <U7 28B 'S2MSI2A'
    epsg                                     int64 8B 32634
  * month                                    (month) int64 8B 1
Attributes:
    spec:        RasterSpec(epsg=32634, bounds=(501400, 4213100, 542360, 4254...
    crs:         epsg:32634
    transform:   | 20.00, 0.00, 501400.00|\n| 0.00,-20.00, 4254060.00|\n| 0.0...
    resolution:  20
# Materialize the lazy dask graph: downloads the data and computes the composite in memory.
median_ds = median_ds.compute()
/Users/syam/virtualenvs/myvenv/lib/python3.13/site-packages/dask/_task_spec.py:744: RuntimeWarning: All-NaN slice encountered
  return self.func(*new_argspec, **kwargs)

Visualize the ROI#

# True-colour preview of the monthly composite; reflectance stretched to [0, 2000].
median_ds.sel(band=["red", "green", "blue"]).plot.imshow(
    row="month", rgb="band", vmin=0, vmax=2000, col_wrap=6
)
<xarray.plot.facetgrid.FacetGrid at 0x168ffe510>
../_images/9f0d226a9d2736f632feda3be16e6da30c91c95fdd22def396679c5ae924c2bf.png

Get data as tensor#

# Convert the (month, band, y, x) composite into a float32 torch tensor.
pixels = torch.from_numpy(median_ds.data.astype(np.float32))
pixels.shape
torch.Size([1, 5, 2048, 2048])

Convert the tensor to patches of size 5x64x64#

# Tile the scene into non-overlapping 64x64 patches (stride == patch size).
patch_size = 64
stride = 64
batch_size, bands, height, width = pixels.shape

# F.unfold flattens every patch into a column: (B, bands*64*64, n_patches).
# Move the patch axis forward, then restore each patch's (bands, h, w) layout.
patches = (
    F.unfold(pixels, kernel_size=patch_size, stride=stride)
    .permute(0, 2, 1)
    .reshape(batch_size, -1, bands, patch_size, patch_size)
)  # (B, n_patches, bands, 64, 64)
patches.shape
torch.Size([1, 1024, 5, 64, 64])
# Preview one patch using its first three bands, brightened for display.
# NOTE(review): bands 0:3 are blue/green/red, shown without reordering, so the
# colours appear channel-swapped here — presumably acceptable for a quick look.
plt.imshow(patches[0,50,0:3,:,:].permute(1,2,0).numpy()/10000*2.5)
plt.axis("off")
plt.show()
../_images/1406dd8f1697897d815ff19a6a593746577a77cb8058cfa5c7bb5212aed78fbb.png
# Drop the batch axis: (n_patches, bands, 64, 64), ready for model input.
patches = patches.reshape(-1, 5, 64, 64)
print(patches.shape)
torch.Size([1024, 5, 64, 64])

Initialize the model#

# Instantiate the network with the hyper-parameters used during training:
# 5 input bands and 4 output classes (ignore / other / water / building).
model = UNet(
    in_channels=5,
    out_classes=4,
    dropout=0.0,
)

# TODO(review): hard-coded absolute path, specific to the author's machine.
model_path = "/Users/syam/Documents/code/meditwin/meditwin-summer-school/docs/chapter4/esa_worldcover_classification/models/best_model.pth"

# map_location="cpu" lets the checkpoint load regardless of the device it was
# saved on; weights_only=True avoids unpickling arbitrary objects.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)

# Pick the best available device instead of assuming Apple's MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
model.eval()
UNet(
  (enc1): Sequential(
    (0): Conv2d(5, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (enc3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (dec3): Sequential(
    (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (dec2): Sequential(
    (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (dec1): Sequential(
    (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (dec0): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Dropout2d(p=0.0, inplace=False)
  )
  (final): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
)

Perform inference on the patches#

# Scale raw reflectance (~0-10000) into [0, 1], matching the training setup.
norm_patches = (patches / 10000.0).clip(min=0.0, max=1.0)

# Run inference on whatever device the model lives on, in memory-safe
# chunks — sending all 1024 patches at once can exhaust GPU/MPS memory.
device = next(model.parameters()).device
chunk_size = 256
outputs = []
with torch.no_grad():
    for i in range(0, norm_patches.size(0), chunk_size):
        chunk = norm_patches[i : i + chunk_size].to(device)
        # Move logits back to the CPU immediately to free device memory.
        outputs.append(model(chunk).cpu())
predictions = torch.cat(outputs, dim=0)

# Per-pixel class index for every patch: (n_patches, 64, 64).
masks = torch.argmax(predictions, dim=1)

Show an example of the inference result#

# Pick one patch to inspect and define a 4-colour categorical palette.
idx = 1020 #1000#170#230
cmap = ListedColormap(
    [
        "black",  # 0 = ignore
        "yellow",  # 1 = other
        "blue",  # 2 = water
        "red",  # 3 = building
    ]
)
# Integer class boundaries 0..4 map each label onto exactly one palette entry.
bounds = np.arange(0, 5)
norm = BoundaryNorm(bounds, cmap.N)
fig, axs = plt.subplots(1, 2, figsize=(10, 7))

# Build a display image: first three bands (blue, green, red), brightened,
# then reordered to red/green/blue for imshow.
rgb = norm_patches[idx, 0:3].detach().cpu().permute(1, 2, 0).numpy() * 2.5
rgb = rgb[..., [2, 1, 0]]
# Pixels whose red band is exactly 0 are treated as nodata — presumably the
# gaps that were filled with 0 in the composite; confirm against the masking step.
mask_nodata = rgb[..., 0] == 0

# Predicted labels for this patch, with nodata pixels forced to class 0 (ignore).
msk = masks[idx].detach().cpu().numpy()
msk[mask_nodata] = 0
msk = msk.astype(np.uint8)

im0 = axs[0].imshow(msk, cmap=cmap, norm=norm)
axs[0].set_title("Label Mask")

# Tick positions 0.5, 1.5, ... centre the labels inside each colour band.
cbar = fig.colorbar(im0, ax=axs[0], shrink=0.5, ticks=np.arange(0.5, 4.5))
cbar.ax.set_yticklabels(["Ignore", "Others", "Water", "Building"])

axs[1].imshow(rgb)
axs[1].set_title("Input RGB Patch")

for ax in axs:
    ax.axis("off")

plt.tight_layout()
plt.show()
../_images/7b379d2580429c177946753d08c661e951b1590ae5313fd0ae9b703a0ebcace5.png

Reconstruct the patch predictions to the original image size#

# Stitch the per-patch predictions back into a full-scene label map.
B, C, H, W = pixels.shape
num_patches_per_img = (H // patch_size) * (W // patch_size)

# F.fold needs float input shaped (B, kernel_area, n_patches). Move the masks
# to the CPU first: aten::col2im has no MPS kernel and would otherwise fall
# back to the CPU with a warning anyway.
masks = masks.to(torch.float32).cpu()
masks = masks.reshape(B, num_patches_per_img, patch_size, patch_size)
masks = masks.reshape(B, num_patches_per_img, -1).permute(0, 2, 1)

# With stride == kernel size the patches do not overlap, so fold performs a
# pure reassembly (no summation over overlapping regions).
reconstructed = F.fold(
    masks, output_size=(H, W), kernel_size=patch_size, stride=stride
)
/Users/syam/virtualenvs/myvenv/lib/python3.13/site-packages/torch/nn/functional.py:5561: UserWarning: The operator 'aten::col2im' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:14.)
  return torch._C._nn.col2im(

Visualize the full image prediction#

# Drop the channel axis and take the single image: a (H, W) label map.
reconstructed = reconstructed.squeeze(1)
reconstructed = reconstructed.cpu().numpy()[0,...]
plt.imshow(reconstructed, cmap=cmap, norm=norm)
plt.axis("off")
plt.show()
../_images/c9431e53c8577d7b76658c9192245210f4693a64c2e450b0e48081a5c8d37fbc.png

Save the prediction as a GeoTIFF#

# Geo-reference the prediction with the transform/CRS of the input composite.
transform = median_ds.attrs["transform"]
crs = median_ds.attrs["crs"]

# Class labels are small integers; store them as uint8 instead of the float
# dtype left over from F.fold, and declare class 0 ("ignore") as nodata.
label_raster = reconstructed.astype(np.uint8)
with rasterio.open(
    "../data/esa_classification_2022_4_classes.tif",
    "w",
    driver="GTiff",
    height=label_raster.shape[0],
    width=label_raster.shape[1],
    count=1,
    dtype=label_raster.dtype,
    crs=crs,
    transform=transform,
    nodata=0,
) as dst:
    dst.write(label_raster, 1)