Day 4 - ML Pipelines#

Pipeline#

A pipeline is a declarative description of a workflow or process, typically expressed in code or configuration, that specifies:

  • What tasks should be run

  • How they are connected or ordered

  • What inputs and outputs are involved

  • Any conditions, retries, or resources needed

Orchestrator#

An orchestrator is a system or tool that manages and coordinates the execution of complex workflows or tasks, especially when those tasks involve multiple steps, dependencies, and resources.
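To make the two definitions concrete, here is a minimal, hypothetical sketch in plain Python (not a real orchestrator): the pipeline is declared as data, and a toy "orchestrator" runs the tasks in dependency order. All names here are illustrative.

```python
from graphlib import TopologicalSorter

# Declarative pipeline: what runs, and how the tasks depend on each other.
pipeline = {
    "extract": set(),          # no dependencies
    "transform": {"extract"},  # runs after extract
    "load": {"transform"},     # runs after transform
}

# Hypothetical task implementations, keyed by name.
tasks = {
    "extract": lambda: [1, 2, 3],
    "transform": lambda data: [x * 10 for x in data],
    "load": lambda data: data,
}

def run(pipeline, tasks):
    """Toy orchestrator: execute tasks in topological order, wiring outputs to inputs."""
    results = {}
    for name in TopologicalSorter(pipeline).static_order():
        args = [results[d] for d in sorted(pipeline[name])]
        results[name] = tasks[name](*args)
    return results

results = run(pipeline, tasks)  # → {'extract': [1, 2, 3], 'transform': [10, 20, 30], 'load': [10, 20, 30]}
```

A real orchestrator adds what this sketch omits: retries, resource management, scheduling, and monitoring.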

Dagster#

Dagster is an open-source data orchestrator for ML, analytics, and ETL (Extract, Transform, Load). It enables you to build, run, and monitor complex data pipelines. Dagster offers:

  • Declarative pipeline definitions (data dependencies and configuration).

  • Type-safe operations.

  • Native support for assets, schedules, and sensors.

  • Integration with popular data tools (e.g., dbt, Spark, MLflow).

Core Concepts#

| Concept | Description |
| --- | --- |
| @op | A function that performs a unit of work |
| @job | A directed graph of ops |
| @asset | A first-class, versioned data product |
| Graph | A reusable composition of ops |
| Resource | An external dependency such as S3, a database, or an API |
| Sensor / Schedule | Triggers jobs by event or time |

Getting Started#

  • Install Dagster:

    pip install dagster dagit
    
  • Initialize a new Dagster project:

    dagster project scaffold --name dagster_tutorial
    cd dagster_tutorial
    
  • Run the Dagster development server:

    dagster dev
    
  • Open the Dagit UI in your browser at http://localhost:3000.

  • Create a new file ops.py in the subdirectory dagster_tutorial and add the following code:

    from dagster import op


    @op
    def get_numbers():
        return [1, 2, 3]


    @op
    def multiply(numbers):
        return [x * 10 for x in numbers]
    
  • Create a new file jobs.py in the subdirectory dagster_tutorial and add the following code:

    from dagster import job

    from .ops import get_numbers, multiply


    @job
    def process_job():
        multiply(get_numbers())
    
  • In the definitions.py file, import the job and add it to the Definitions object:

    from dagster import Definitions, load_assets_from_modules
    from dagster_tutorial import assets  # noqa: TID252
    from dagster_tutorial.jobs import process_job
    all_assets = load_assets_from_modules([assets])
    defs = Definitions(
        assets=all_assets,
        jobs=[process_job],
    )
    
  • Modify the multiply function in ops.py to get runtime config:

    from typing import List

    from dagster import Config, op


    class MultiplyConfig(Config):
        factor: int


    @op
    def multiply(config: MultiplyConfig, numbers: List[int]):
        return [x * config.factor for x in numbers]
    
  • In the launchpad, you can now run the process_job with a configuration:

    ops:
      multiply:
        config:
          factor: 10
    
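The same run configuration can also be supplied programmatically. The sketch below uses a plain dataclass as a stand-in for Dagster's `Config` class (hypothetical names, no Dagster dependency) to show how the nested launchpad YAML maps onto the op's typed configuration:

```python
from dataclasses import dataclass

# Stand-in for the MultiplyConfig(Config) class defined in ops.py.
@dataclass
class MultiplyConfig:
    factor: int

# The launchpad YAML above, expressed as the equivalent nested dict.
run_config = {"ops": {"multiply": {"config": {"factor": 10}}}}

# Dagster resolves the nested dict into a typed config object per op:
cfg = MultiplyConfig(**run_config["ops"]["multiply"]["config"])

def multiply(config, numbers):
    return [x * config.factor for x in numbers]

result = multiply(cfg, [1, 2, 3])  # → [10, 20, 30]
```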
  • You can enable logging anywhere in your Dagster project with the built-in logger:

    from dagster import get_dagster_logger
    logger = get_dagster_logger()
    logger.info("This is an info message")
    
  • You can also use assets to define, persist and version your data products. For example, you can create a new file assets.py in the subdirectory dagster_tutorial and add the following code:

    from dagster import asset
    @asset
    def raw_data():
        return [1, 2, 3, 4]
    @asset
    def squared_data(raw_data):
        return [x**2 for x in raw_data]
    
  • You can add a schedule to run a job at a specific time. To do that, add the following code to the definitions.py file:

    from dagster import ScheduleDefinition

    hourly_schedule = ScheduleDefinition(
        job=process_job,
        cron_schedule="0 * * * *",  # Every hour
    )

    defs = Definitions(
        assets=all_assets,
        jobs=[process_job],
        schedules=[hourly_schedule],
    )
    
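The cron string "0 * * * *" means "minute 0 of every hour". As a plain-Python illustration of what the schedule computes (not Dagster code), the next trigger after any instant is the next top of the hour:

```python
from datetime import datetime, timedelta

def next_hourly_trigger(now: datetime) -> datetime:
    """Next time matching cron '0 * * * *', i.e. the next top of the hour."""
    candidate = now.replace(minute=0, second=0, microsecond=0)
    if candidate <= now:
        candidate += timedelta(hours=1)
    return candidate

t = next_hourly_trigger(datetime(2024, 1, 1, 9, 30))
# t == datetime(2024, 1, 1, 10, 0)
```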

MLflow#

  • MLflow is an open-source platform for managing the ML lifecycle, including experimentation, reproducibility, and deployment.

  • It provides a central repository for tracking experiments, packaging code into reproducible runs, and sharing and deploying models.

  • It has four main components:

    • Tracking: Log and query experiments.

    • Projects: Package code in a reusable and reproducible way.

    • Models: Manage and deploy models from various ML libraries.

    • Registry: Store and manage models in a central repository.

  • Dagster can be used to orchestrate ML workflows and integrate with MLflow for tracking experiments and managing models.

  • You can use Dagster to define and run ML pipelines, and use MLflow to log and track experiments, models, and artifacts.

  • You can use Dagster’s @op, @job, @asset, @schedule, and @sensor decorators to define MLflow-backed operations, jobs, assets, schedules, and sensors, and call MLflow’s Python API inside them to log and track experiments.

To test MLflow access, you can run the following Python code:

import os
from dotenv import load_dotenv
import mlflow

load_dotenv("../../.env")

MLFLOW_SERVER_URL = os.getenv("MLFLOW_SERVER_URL", "http://localhost:5000")
MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME")
MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD")

MY_PREFIX = "mohanad-experiment"
os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD


mlflow.set_tracking_uri(MLFLOW_SERVER_URL)
mlflow.set_experiment(f"/{MY_PREFIX}/classification")
with mlflow.start_run():
    mlflow.log_metric("metric1", 1.0)

Dagster + Xarray + Dask to Train ERA5 Forecasting Model#

You can run the advanced Dagster project era5_forecast, which integrates xarray and Dask with Dagster.

Instructions#

1- Create the dagster project:

dagster project scaffold -n era5_forecast
cd era5_forecast

2- Assuming the dependencies are installed, run the dagster server:

dagster dev

3- Open the Dagit UI in your browser at http://localhost:3000.

4- On the Jobs tab, select the training_pipeline job and click the Launchpad button to run it.

5- Open the Dask UI in your browser at http://localhost:8787/status.

ops.py

import os

import dask.array
import numpy as np
import xarray as xr
from dagster import Out, get_dagster_logger, op
from xgboost.dask import DaskDMatrix, train

from .resources import DaskResource

logger = get_dagster_logger()

# ----------- CONFIG -----------
ZARR_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
LAGS = list(range(1, 7))  # 6 days of autoregressive lag features


def create_lagged_features(da, lags):
    lagged = [da.shift(time=lag).rename(f"lag_{lag}") for lag in lags]
    return xr.merge(lagged + [da.rename("target")])


# ----------- OPS -----------
@op
def load_and_preprocess_data():
    ds = xr.open_zarr(
        ZARR_URL,
        chunks={"time": 1},  # type: ignore
        storage_options={"token": "anon"},
    )
    temp = ds["2m_temperature"].sel(
        time=slice("2018-01-01", "2018-01-31"),
        # latitude=slice(40, 50),
        # longitude=slice(0, 10),
    )
    daily = temp.resample(time="1D").mean()
    daily = daily.persist()
    return daily


@op(out={"features": Out(), "labels": Out()})
def generate_training_data(daily):
    ds_lagged = create_lagged_features(daily, lags=LAGS)
    X = xr.concat([ds_lagged[f"lag_{lag}"] for lag in LAGS], dim="feature")
    X = (
        X.stack(sample=("time", "latitude", "longitude"))
        .transpose("sample", "feature")
        .data
    )
    y = ds_lagged["target"].chunk({"time": -1, "latitude": 10, "longitude": 10})
    y = y.stack(sample=("time", "latitude", "longitude")).data
    return X, y


@op
def train_model(my_dask_resource: DaskResource, X, y):
    client = my_dask_resource.make_dask_cluster()
    logger.info("Dask dashboard link %s", client.dashboard_link)

    sample_frac = 0.05
    logger.info(f"Sampling {sample_frac * 100:.0f}% of data before rechunking")

    n_samples = X.shape[0]
    sample_size = int(n_samples * sample_frac)
    random_indices = np.random.permutation(n_samples)[:sample_size]

    # Turn indices into a Dask array for indexing
    random_indices = dask.array.from_array(random_indices, chunks=(sample_size,))

    # Sample X and y using Dask indexing
    X = X[random_indices]
    y = y[random_indices]

    logger.info("X.shape: %s", X.shape)
    logger.info("y.shape: %s", y.shape)

    logger.info("Splitting sampled data into train/val sets using Dask...")
    val_frac = 0.2
    val_size = int(sample_size * val_frac)
    train_size = sample_size - val_size

    X_train = X[:train_size]
    y_train = y[:train_size]
    X_val = X[train_size:]
    y_val = y[train_size:]

    chunk_size = 32
    X_train = X_train.rechunk((chunk_size, -1))
    y_train = y_train.rechunk((chunk_size,))
    X_val = X_val.rechunk((chunk_size, -1))
    y_val = y_val.rechunk((chunk_size,))

    logger.info("Creating DaskDMatrix for training and validation...")
    dtrain = DaskDMatrix(client, X_train, y_train)
    dval = DaskDMatrix(client, X_val, y_val)

    logger.info("Starting training with validation...")
    output = train(
        client,
        {
            "verbosity": 2,
            "tree_method": "hist",
            "objective": "reg:squarederror",
        },
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, "train"), (dval, "validation")],
        early_stopping_rounds=4,
    )
    booster = output["booster"]
    best_iteration = booster.best_iteration
    logger.info(f"Training stopped at iteration: {best_iteration}")
    os.makedirs("models", exist_ok=True)
    booster.save_model("models/xgb_model.json")
    logger.info("Model saved to models/xgb_model.json")

    client.close()
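The `create_lagged_features` helper above shifts the series along the time dimension so each target day is paired with its previous days. The same idea in plain NumPy, assuming a simple 1-D daily series (illustrative data only):

```python
import numpy as np

def lagged_matrix(series: np.ndarray, lags: list[int]) -> tuple[np.ndarray, np.ndarray]:
    """Build (X, y): column k of X is the series shifted back by lags[k] days."""
    n = len(series)
    max_lag = max(lags)
    # Each row of X holds the lagged values aligned with one target in y.
    X = np.column_stack([series[max_lag - lag : n - lag] for lag in lags])
    y = series[max_lag:]
    return X, y

temps = np.array([10.0, 11.0, 12.0, 13.0, 14.0, 15.0])
X, y = lagged_matrix(temps, lags=[1, 2])
# y = [12, 13, 14, 15]; row 0 of X is [11, 10]: the 1-day and 2-day lags of 12
```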

resources.py

from dagster import ConfigurableResource
from dask.distributed import Client, LocalCluster


class DaskResource(ConfigurableResource):
    n_workers: int

    def make_dask_cluster(self) -> Client:
        client = Client(LocalCluster(n_workers=self.n_workers))
        return client

jobs.py

from dagster import job

from .ops import generate_training_data, load_and_preprocess_data, train_model


@job
def training_pipeline():
    daily = load_and_preprocess_data()
    X, y = generate_training_data(daily)
    train_model(X, y)  # pylint:disable=E1120

definitions.py

from dagster import Definitions, load_assets_from_modules

from era5_forecast import assets  # noqa: TID252
from era5_forecast.jobs import training_pipeline
from era5_forecast.resources import DaskResource  # noqa: TID252

all_assets = load_assets_from_modules([assets])

defs = Definitions(
    assets=all_assets,
    jobs=[training_pipeline],
    resources={"my_dask_resource": DaskResource(n_workers=4)},
)

Dagster + Stackstac + Dask + MLflow to Train Sentinel-2 Land Cover Classification Model#

You can run the advanced Dagster project esa_worldcover_classification, which integrates stackstac, xarray, and Dask with Dagster and MLflow.

Note

The ESA WorldCover dataset is a global land cover map produced by the European Space Agency, offering 10-meter resolution classification based on Sentinel-1 and Sentinel-2 satellite imagery. Released in 2021, it provides detailed and consistent information on land cover types such as forests, croplands, urban areas, and water bodies. Designed to support environmental monitoring, climate change studies, and sustainable land management, WorldCover is freely accessible and regularly updated, making it a valuable resource for researchers, policymakers, and Earth observation applications worldwide. The ESA WorldCover dataset includes 11 land cover classes, based on the UN FAO Land Cover Classification System (LCCS). Here’s the list of classes:

| Class ID | Land Cover Class |
| --- | --- |
| 10 | Tree cover |
| 20 | Shrubland |
| 30 | Grassland |
| 40 | Cropland |
| 50 | Built-up |
| 60 | Bare / sparse vegetation |
| 70 | Snow and ice |
| 80 | Permanent water bodies |
| 90 | Herbaceous wetland |
| 95 | Mangroves |
| 100 | Moss and lichen |

WorldCover Documentation
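Training a classifier usually requires remapping these sparse WorldCover IDs (10, 20, …, 100) to contiguous indices starting at 0. The project's `CLASS_MAPPING` (used later in `save_to_zarr`) performs a remapping of this kind, though its exact grouping may differ; the sketch below is an assumed one-to-one mapping in NumPy:

```python
import numpy as np

# Hypothetical mapping from WorldCover class IDs to contiguous training indices.
CLASS_IDS = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]
ID_TO_INDEX = {cid: i for i, cid in enumerate(CLASS_IDS)}

def remap_labels(mask: np.ndarray) -> np.ndarray:
    """Remap a WorldCover label raster to contiguous class indices 0..10."""
    out = np.zeros_like(mask, dtype=np.int64)
    for cid, idx in ID_TO_INDEX.items():
        out[mask == cid] = idx
    return out

labels = remap_labels(np.array([[10, 50], [95, 100]]))
# → [[0, 4], [9, 10]]
```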

Instructions#

1- Run the MLflow server:

mlflow server \
  --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root ./mlruns \
  --host 0.0.0.0 \
  --port 8000

2- Create the dagster project:

dagster project scaffold -n esa_worldcover_classification
cd esa_worldcover_classification

3- Assuming the dependencies are installed, run the dagster server:

dagster dev

4- Open the Dagit UI in your browser at http://localhost:3000.

5- On the Jobs tab, select the esa_worldcover_classification job and click the Launchpad button to run it. Use the following configuration:

Note

Change the MLflow experiment name to include your name, e.g. /mohanad_s2_classification. Make sure the MLflow server URL, username, and password are set in the .env file.

ops:
  fetch_s2_stack:
    config:
      bbox: [21.0, 38.0, 21.5, 38.5]
      time_range: "2020-01-01/2020-01-31"

  fetch_worldcover_stack:
    config:
      bbox: [21.0, 38.0, 21.5, 38.5]
      time_range: "2020-01-01/2020-01-31"

  save_to_zarr:
    config:
      zarr_cache_dir: "cache"

  train_unet:
    config:
      patch_size: 64
      stride: 32
      batch_size: 16
      model: "unet"
      epochs: 1
      learning_rate: 0.001
      loss: "cross_entropy"
      num_workers: 4
      in_channels: 5
      out_classes: 4
      mlflow_tracking_uri: null
      mlflow_experiment_name: "s2_classification"
      model_path: "models/unet_model.pth"
      zarr_cache_dir: "cache"
      device: "mps"

6- Open the MLflow UI in your browser at http://localhost:8000.

ops.py

import gc
import os
from typing import Optional

from dagster import Config, Out, get_dagster_logger, op
from pystac_client import Client
import planetary_computer
import requests
import stackstac
import xarray as xr
from rasterio.enums import Resampling

from .utils import clean_attrs, clean_coords, map_labels, transform_bbox
from .configurations import (
    AWS_STAC_API,
    CLASS_MAPPING,
    PLANETARY_COMPUTER_STAC_API,
    PLANETARY_COMPUTER_TOKEN_URL,
)
from .train import train_model


logger = get_dagster_logger()


class StackConfig(Config):
    bbox: list[float]
    time_range: str


class ZarrConfig(Config):
    zarr_cache_dir: str


class TrainUNetConfig(Config):
    patch_size: int
    stride: int
    num_workers: int
    batch_size: int
    model: str
    epochs: int
    learning_rate: float
    loss: str
    in_channels: int
    out_classes: int
    mlflow_tracking_uri: Optional[str] = None
    mlflow_experiment_name: str
    model_path: str
    zarr_cache_dir: str
    device: str


@op(out={"pixels_output": Out(), "proj_output": Out(), "valid_mask": Out()})
def fetch_s2_stack(config: StackConfig) -> tuple[xr.DataArray, str, xr.DataArray]:
    catalog = Client.open(AWS_STAC_API)
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=config.bbox,
        datetime=config.time_range,
        query={
            "eo:cloud_cover": {"lt": 10},
            "s2:nodata_pixel_percentage": {"lt": 10},
        },
        max_items=5000,
    )
    all_items = search.item_collection()
    items = []
    granules = []
    for item in all_items:
        if item.properties["s2:granule_id"] not in granules:
            items.append(item)
            granules.append(item.properties["s2:granule_id"])

    logger.info("Found %d Sentinel-2-L2A items", len(items))
    proj = items[0].properties["proj:code"]
    bbox_utm = transform_bbox(config.bbox, items[0].properties["proj:code"])
    assets = ["blue", "green", "red", "nir", "swir16", "scl"]
    ds = stackstac.stack(
        items,
        assets=assets,
        bounds=bbox_utm,  # pyright: ignore
        resolution=20,
        epsg=int(proj.split(":")[-1]),
        dtype="float64",  # pyright: ignore
        rescale=False,
        snap_bounds=True,
        resampling=Resampling.nearest,
        chunksize=(1, 1, 512, 512),
    )

    cloud_values = [0, 1, 2, 3, 8, 9, 10]  # cloud, shadows, cirrus, etc.
    scl_mask = ds.sel(band="scl")
    bands = ds.sel(band=["blue", "green", "red", "nir", "swir16"])
    valid_mask = ~scl_mask.isin(cloud_values)
    bands_masked = bands.where(valid_mask)
    median_ds = bands_masked.groupby("time.month").median("time", skipna=True)
    valid_mask_median = ~xr.ufuncs.isnan(median_ds.isel(band=0))
    median_ds = median_ds.fillna(0)
    logger.info("Median dataset shape: %s", median_ds.shape)
    return (median_ds, proj, valid_mask_median)


@op
def fetch_worldcover_stack(
    config: StackConfig, proj: str, valid_mask: xr.DataArray
) -> xr.DataArray:
    logger.info("Projection: %s", proj)
    catalog = Client.open(PLANETARY_COMPUTER_STAC_API)
    response = requests.get(
        f"{PLANETARY_COMPUTER_TOKEN_URL}/esa-worldcover", timeout=10
    )

    if response.status_code == 200:
        response = response.json()
        token = response["token"]
        _ = {"Authorization": f"Bearer {token}"}
    else:
        raise RuntimeError(f"Failed to get token. Status code: {response.status_code}")
    search = catalog.search(
        collections=["esa-worldcover"],
        bbox=config.bbox,
        query={
            "start_datetime": {"eq": "2020-01-01T00:00:00Z"},
            "end_datetime": {"eq": "2020-12-31T23:59:59Z"},
        },
    )
    all_items = search.item_collection()
    items = [planetary_computer.sign_item(item) for item in all_items]
    logger.info("Found %d ESA World Cover items", len(items))
    bbox_utm = transform_bbox(config.bbox, proj)
    ds = stackstac.stack(
        items,
        assets=["map"],
        bounds=bbox_utm,  # pyright: ignore
        resolution=20,
        epsg=int(proj.split(":")[-1]),
        dtype="float64",  # pyright: ignore
        rescale=False,
        snap_bounds=True,
        resampling=Resampling.nearest,
        chunksize=(1, 1, 512, 512),
    )
    ds = ds.sel(time=ds.notnull().any(dim=["x", "y", "band"]))
    ds = ds.where(valid_mask)
    ds = ds.fillna(0)
    logger.info("WorldCover dataset shape: %s", ds.shape)
    return ds


@op
def save_to_zarr(
    config: ZarrConfig, features: xr.DataArray, labels: xr.DataArray
) -> str:
    features = features.squeeze()
    labels = labels.squeeze()
    features = features.drop_vars(
        [dim for dim in ["time", "month"] if dim in features.dims]
    )
    labels = labels.drop_vars([dim for dim in ["time", "month"] if dim in labels.dims])
    features, labels = xr.align(features, labels, join="inner")
    labels = map_labels(labels, CLASS_MAPPING)
    img_zarr_path = os.path.join(config.zarr_cache_dir, "img.zarr")
    mask_zarr_path = os.path.join(config.zarr_cache_dir, "mask.zarr")
    if not os.path.exists(img_zarr_path) and not os.path.exists(mask_zarr_path):
        features = features.chunk({"x": 512, "y": 512})
        labels = labels.chunk({"x": 512, "y": 512})
        features.data = features.data.compute_chunk_sizes()
        labels.data = labels.data.compute_chunk_sizes()
        features = clean_attrs(features)
        features = clean_coords(features)
        features.to_zarr(img_zarr_path)

        labels = clean_attrs(labels)
        labels = clean_coords(labels)
        labels.to_zarr(mask_zarr_path)
        del features
        del labels
        gc.collect()
    return config.zarr_cache_dir


@op
def train_unet(config: TrainUNetConfig, zarr_dir: str):  # pylint:disable=W0613
    train_model(config)
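The cloud-masking and compositing step in `fetch_s2_stack` can be illustrated with plain NumPy, assuming a tiny stack of scenes with an SCL (scene classification) band; the array values here are invented for the example:

```python
import numpy as np

# SCL classes treated as invalid (clouds, shadows, cirrus, etc.), as in fetch_s2_stack.
CLOUD_SCL = [0, 1, 2, 3, 8, 9, 10]

def median_composite(bands: np.ndarray, scl: np.ndarray) -> np.ndarray:
    """bands: (time, y, x) reflectance; scl: (time, y, x) scene classification.
    Mask cloudy pixels, then take the per-pixel temporal median of valid scenes."""
    invalid = np.isin(scl, CLOUD_SCL)
    masked = np.where(invalid, np.nan, bands)
    return np.nanmedian(masked, axis=0)

bands = np.array([[[100.0]], [[200.0]], [[300.0]]])  # 3 scenes, 1x1 pixel
scl = np.array([[[4]], [[9]], [[4]]])                # middle scene is cloudy
comp = median_composite(bands, scl)                  # median of [100, 300] → [[200.0]]
```

The real op does the same per month via `groupby("time.month").median(...)` on lazy Dask-backed arrays.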

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.segformer import (
    SegformerForSemanticSegmentation,
    SegformerConfig,
)


class UNetBaseLine(nn.Module):
    def __init__(self, in_channels=4, out_classes=11, dropout=0.2):
        super(UNetBaseLine, self).__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.bottleneck = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.dec2 = conv_block(512 + 128, 128)
        self.dec1 = conv_block(128 + 64, 64)

        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))

        d2 = self._upsample_concat(b, e2)
        d2 = self.dec2(d2)

        d1 = self._upsample_concat(d2, e1)
        d1 = self.dec1(d1)

        return self.final(d1)  # (B, 11, H, W)

    def _upsample_concat(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)


class UNet(nn.Module):
    def __init__(self, in_channels=4, out_classes=11, dropout=0.1, init_weights=True):
        super(UNet, self).__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(256, 512)

        # Decoder
        self.dec3 = conv_block(512 + 256, 256)
        self.dec2 = conv_block(256 + 128, 128)
        self.dec1 = conv_block(128 + 64, 64)
        self.dec0 = conv_block(64, 32)

        # Final classifier
        self.final = nn.Conv2d(32, out_classes, kernel_size=1)

        if init_weights:
            self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        e1 = self.enc1(x)  # (B, 64, H, W)
        e2 = self.enc2(self.pool(e1))  # (B, 128, H/2, W/2)
        e3 = self.enc3(self.pool(e2))  # (B, 256, H/4, W/4)
        b = self.bottleneck(self.pool(e3))  # (B, 512, H/8, W/8)

        d3 = self._upsample_concat(b, e3)  # (B, 512+256, H/4, W/4)
        d3 = self.dec3(d3)

        d2 = self._upsample_concat(d3, e2)  # (B, 256+128, H/2, W/2)
        d2 = self.dec2(d2)

        d1 = self._upsample_concat(d2, e1)  # (B, 128+64, H, W)
        d1 = self.dec1(d1)

        d0 = self.dec0(d1)  # (B, 32, H, W)

        return self.final(d0)  # logits: (B, out_classes, H, W)

    def _upsample_concat(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)


class SegformerB0FourBand(nn.Module):
    def __init__(self, in_channels=4, out_classes=11):
        super().__init__()

        # Step 1: Load pretrained model to extract 3-channel conv weights
        pretrained_model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512"
        )
        pretrained_conv = pretrained_model.segformer.encoder.patch_embeddings[0].proj
        pretrained_weights = pretrained_conv.weight  # type: ignore # shape: [32, 3, 7, 7]

        # Step 2: Create modified config
        config = SegformerConfig.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512"
        )
        config.num_channels = in_channels
        config.num_labels = out_classes

        # Step 3: Create your model
        self.model = SegformerForSemanticSegmentation(config)

        # Step 4: Replace first conv with 4-band version and copy weights
        old_conv = self.model.segformer.encoder.patch_embeddings[0].proj
        new_conv = nn.Conv2d(
            4,
            old_conv.out_channels,  # type: ignore
            kernel_size=old_conv.kernel_size,  # type: ignore
            stride=old_conv.stride,  # type: ignore
            padding=old_conv.padding,  # type: ignore
            bias=old_conv.bias is not None,  # type: ignore
        )
        with torch.no_grad():
            new_conv.weight[:, :3] = pretrained_weights  # type: ignore # use original RGB weights
            new_conv.weight[:, 3] = pretrained_weights[:, 0]  # type: ignore # init 4th like Red
            if old_conv.bias is not None:  # type: ignore
                new_conv.bias.copy_(pretrained_conv.bias)  # type: ignore
        self.model.segformer.encoder.patch_embeddings[0].proj = new_conv

        # Step 5: Load other pretrained weights (except the mismatched ones)
        state_dict = pretrained_model.state_dict()
        for key in [
            "segformer.encoder.patch_embeddings.0.proj.weight",
            "decode_head.classifier.weight",
            "decode_head.classifier.bias",
        ]:
            state_dict.pop(key, None)
        self.model.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        logits = self.model(pixel_values=x).logits
        logits = F.interpolate(
            logits, size=x.shape[-2:], mode="bilinear", align_corners=False
        )
        logits = logits.contiguous()
        return logits

dataset.py

import os

import torch
from torch.utils.data import Dataset
import xarray as xr

from .utils import compute_valid_patch_indices


class XarrayPatchDataset(Dataset):
    def __init__(self, patch_size: int, stride: int, zarr_cache_dir: str, indices=None):
        """
        Dataset for extracting patches from xarray images, using precomputed valid patch indices.

        Args:
            patch_size (int): Size of square patches.
            stride (int): Stride of the sliding window.
            zarr_cache_dir (str): Directory holding the cached data in Zarr format.
            indices (list, optional): Precomputed patch indices; computed if omitted.
        """
        self.patch_size = patch_size
        self.stride = stride
        self.zarr_cache_dir = zarr_cache_dir
        self.img_zarr_path = os.path.join(zarr_cache_dir, "img.zarr")
        self.mask_zarr_path = os.path.join(zarr_cache_dir, "mask.zarr")
        self.img = xr.open_zarr(self.img_zarr_path)
        self.img = next(iter(self.img.data_vars.values()))
        self.mask = xr.open_zarr(self.mask_zarr_path)
        self.mask = next(iter(self.mask.data_vars.values()))
        self.mask.data.compute_chunk_sizes()
        if indices is None:
            indices = compute_valid_patch_indices(
                self.mask, patch_size, stride, threshold=0.1
            )
        self.valid_indices = indices

    def __len__(self):
        """Return the number of valid patches."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Get the patch at the given index."""
        # Get the top-left coordinates of the patch
        i, j = self.valid_indices[idx]

        # Use Dask arrays directly for patch extraction
        img_patch = self.img[:, i : i + self.patch_size, j : j + self.patch_size]
        mask_patch = self.mask[i : i + self.patch_size, j : j + self.patch_size]

        # Convert to Dask arrays (lazily loaded)
        img_patch = img_patch.data  # Dask array
        mask_patch = mask_patch.data  # Dask array

        # Compute and convert to PyTorch tensors
        x = torch.tensor(img_patch.compute(), dtype=torch.float32)
        y = torch.tensor(mask_patch.compute(), dtype=torch.long)

        # Scale and clip reflectance values
        x = x / 10000.0
        x = x.clip(min=0.0, max=1.0)

        return x, y
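`compute_valid_patch_indices` is imported above but its body is not shown. A plausible sketch, under the assumption that it slides a window over the label mask and keeps patches whose fraction of labelled (non-zero) pixels exceeds the threshold, is:

```python
import numpy as np

def compute_valid_patch_indices(mask: np.ndarray, patch_size: int, stride: int,
                                threshold: float = 0.1) -> list[tuple[int, int]]:
    """Return top-left (i, j) corners of patches whose labelled fraction exceeds threshold.
    Pixels with value 0 are assumed to be unlabelled (assumption for this sketch)."""
    h, w = mask.shape
    indices = []
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patch = mask[i : i + patch_size, j : j + patch_size]
            if (patch > 0).mean() > threshold:
                indices.append((i, j))
    return indices

mask = np.zeros((4, 4))
mask[:2, :2] = 1  # labels only in the top-left corner
idx = compute_valid_patch_indices(mask, patch_size=2, stride=2, threshold=0.1)
# → [(0, 0)]: only the top-left patch contains enough labelled pixels
```

The real helper operates on an xarray DataArray backed by Dask, but the index arithmetic is the same.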

train.py

import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import mlflow
from dagster import get_dagster_logger
import xarray as xr
from sklearn.model_selection import train_test_split
from torchmetrics import JaccardIndex
from torchmetrics.segmentation import DiceScore, MeanIoU
from dotenv import load_dotenv
from .loss import FocalLoss
from .model import UNet, UNetBaseLine, SegformerB0FourBand
from .dataset import XarrayPatchDataset
from .utils import compute_valid_patch_indices

load_dotenv("../../.env")

MLFLOW_SERVER_URL = os.getenv("MLFLOW_SERVER_URL")
MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME")
MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
S3_END_POINT = os.getenv("S3_END_POINT")
logger = get_dagster_logger()


def train_model(config):
    os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
    mlflow_tracking_uri = config.mlflow_tracking_uri or MLFLOW_SERVER_URL
    if config.model == "unet":
        model = UNet(
            in_channels=config.in_channels,
            out_classes=config.out_classes,
            dropout=0.2,
        )
    elif config.model == "segformer":
        model = SegformerB0FourBand(
            in_channels=config.in_channels,
            out_classes=config.out_classes,
        )
    else:
        model = UNetBaseLine(
            in_channels=config.in_channels, out_classes=config.out_classes
        )
    mask = xr.open_zarr(os.path.join(config.zarr_cache_dir, "mask.zarr"))
    mask = next(iter(mask.data_vars.values()))
    full_indices = compute_valid_patch_indices(
        mask, config.patch_size, config.stride, threshold=0.1
    )
    # Shuffle + split
    train_indices, val_indices = train_test_split(
        full_indices, test_size=0.2, random_state=42
    )
    train_dataset = XarrayPatchDataset(
        config.patch_size, config.stride, config.zarr_cache_dir, indices=train_indices
    )
    valid_dataset = XarrayPatchDataset(
        config.patch_size, config.stride, config.zarr_cache_dir, indices=val_indices
    )
    if config.loss == "focal":
        criterion = FocalLoss(alpha=None, gamma=2.0, ignore_index=0)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    jaccard = JaccardIndex(
        task="multiclass",
        num_classes=config.out_classes,
        average="weighted",
        zero_division=0.0,
        ignore_index=0,
    ).to(config.device)
    dice = DiceScore(
        num_classes=config.out_classes, input_format="index", zero_division=0.0
    ).to(config.device)
    iou = MeanIoU(num_classes=config.out_classes, input_format="index").to(
        config.device
    )
    model.to(config.device)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        prefetch_factor=2,
    )
    val_loader = DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        prefetch_factor=2,
    )
    n_train = len(train_loader)
    n_val = len(val_loader)
    best_val_loss = float("inf")
    best_model_path = os.path.join(
        os.path.dirname(config.model_path), "best_model.pth"
    )
    os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
    os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD
    os.environ["AWS_ACCESS_KEY_ID"] = S3_ACCESS_KEY
    os.environ["AWS_SECRET_ACCESS_KEY"] = S3_SECRET_ACCESS_KEY
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = S3_END_POINT
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(config.mlflow_experiment_name)
    with mlflow.start_run():
        mlflow.log_params(
            {
                "model": config.model,
                "num_epochs": config.epochs,
                "learning_rate": config.learning_rate,
                "loss": config.loss,
                "batch_size": config.batch_size,
                "patch_size": config.patch_size,
                "stride": config.stride,
                "num_workers": config.num_workers,
                "in_channels": config.in_channels,
                "out_classes": config.out_classes,
                "zarr_cache_dir": config.zarr_cache_dir,
                "model_path": config.model_path,
                "best_model_path": best_model_path,
            }
        )
        for epoch in range(config.epochs):
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
            total_train_loss = 0
            total_train_dice = 0
            total_train_jaccard = 0
            total_train_iou = 0
            total_val_loss = 0
            total_val_dice = 0
            total_val_jaccard = 0
            total_val_iou = 0
            model.train()
            for X_batch, y_batch in progress_bar:
                X_batch = X_batch.to(config.device)
                y_batch = y_batch.to(config.device)
                out = model(X_batch)
                loss = criterion(out, y_batch)
                preds = torch.argmax(out, dim=1)
                dice_coef = dice(preds, y_batch)
                jaccard_index = jaccard(preds, y_batch)
                iou_index = iou(preds, y_batch)
                optimizer.zero_grad()
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()
                total_train_dice += dice_coef.item()
                total_train_jaccard += jaccard_index.item()
                total_train_iou += iou_index.item()

            mlflow.log_metric("loss", total_train_loss / n_train, step=epoch)
            mlflow.log_metric("dice", total_train_dice / n_train, step=epoch)
            mlflow.log_metric("jaccard", total_train_jaccard / n_train, step=epoch)
            mlflow.log_metric("iou", total_train_iou / n_train, step=epoch)
            model.eval()
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(config.device)
                    y_batch = y_batch.to(config.device)
                    out = model(X_batch)
                    loss = criterion(out, y_batch)
                    preds = torch.argmax(out, dim=1)
                    total_val_loss += loss.item()
                    total_val_dice += dice(preds, y_batch).item()
                    total_val_jaccard += jaccard(preds, y_batch).item()
                    total_val_iou += iou(preds, y_batch).item()
            avg_val_loss = total_val_loss / n_val
            mlflow.log_metric("val_loss", avg_val_loss, step=epoch)
            mlflow.log_metric("val_dice", total_val_dice / n_val, step=epoch)
            mlflow.log_metric("val_jaccard", total_val_jaccard / n_val, step=epoch)
            mlflow.log_metric("val_iou", total_val_iou / n_val, step=epoch)
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)
                logger.info(
                    f"Saved new best model at epoch {epoch+1} with val_loss={avg_val_loss:.4f}"
                )
        torch.save(model.state_dict(), config.model_path)
        input_example = X_batch[0].unsqueeze(0).cpu().numpy()
        mlflow.pytorch.log_model(model, artifact_path="last_model", input_example=input_example)
        model.load_state_dict(torch.load(best_model_path))
        mlflow.pytorch.log_model(model, artifact_path="best_model", input_example=input_example)
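`train_model` reads its settings from a plain attribute-style object. A hypothetical example of the fields it accesses (the names are taken from the function body; the values here are purely illustrative, and in the pipeline the config would come from the Dagster op):

```python
from types import SimpleNamespace

# Illustrative configuration for train_model. Every attribute name
# below corresponds to a config.<name> access in the function above.
config = SimpleNamespace(
    model="unet",            # "unet" | "segformer" | anything else -> UNetBaseLine
    loss="focal",            # "focal" -> FocalLoss, otherwise CrossEntropyLoss
    epochs=10,
    learning_rate=1e-3,
    batch_size=8,
    patch_size=256,
    stride=128,
    num_workers=4,
    in_channels=4,           # e.g. four Sentinel-2 bands
    out_classes=12,
    device="cpu",
    zarr_cache_dir="data/zarr_cache",
    model_path="models/last_model.pth",
    mlflow_tracking_uri="",  # empty -> fall back to MLFLOW_SERVER_URL from .env
    mlflow_experiment_name="landcover",
)
```

Note that the MLflow and S3 credentials are not part of this object; they are read from environment variables loaded via `dotenv` at import time.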

utils.py

import logging
import xarray as xr
import numpy as np
from pyproj import Transformer
import dask.array as da
from numpy.lib.stride_tricks import sliding_window_view


def transform_bbox(bbox_wgs84, dest_crs):
    """Reproject a (minx, miny, maxx, maxy) bounding box from WGS84 to `dest_crs`."""
    transformer = Transformer.from_crs("epsg:4326", dest_crs, always_xy=True)
    minx, miny = transformer.transform(bbox_wgs84[0], bbox_wgs84[1])
    maxx, maxy = transformer.transform(bbox_wgs84[2], bbox_wgs84[3])
    return [minx, miny, maxx, maxy]


def map_labels(labels: xr.DataArray, mapping: dict[int, int]) -> xr.DataArray:
    """Remap label values according to `mapping`; unmapped values become 0."""
    data = labels.values
    mapped = np.full_like(data, fill_value=0, dtype=np.int32)
    for original, new in mapping.items():
        mapped[data == original] = new
    return xr.DataArray(mapped, coords=labels.coords, dims=labels.dims)


def clean_attrs(da_or_ds):
    # Only keep simple, JSON-serializable entries in attrs
    da_or_ds.attrs = {
        k: v
        for k, v in da_or_ds.attrs.items()
        if isinstance(v, (str, int, float, list, dict, bool, type(None)))
    }
    return da_or_ds


def clean_coords(  # pylint:disable=W0102
    ds: xr.DataArray, keep: list[str] = []
) -> xr.DataArray:
    """
    Remove non-dimension coordinates from an xarray object,
    except those explicitly listed in `keep`.

    Parameters
    ----------
    ds : xr.DataArray
        The object to clean.
    keep : list of str, optional
        Names of non-dimension coordinates to keep.

    Returns
    -------
    xr.DataArray
        Cleaned object with only dimension coordinates and the kept metadata.
    """
    # All dimension coordinates
    dim_coords = list(ds.dims)

    # Identify coordinates to drop (non-dim, not explicitly kept)
    drop_coords = [c for c in ds.coords if c not in dim_coords and c not in keep]

    return ds.drop_vars(drop_coords)


def compute_valid_patch_indices(
    mask: xr.DataArray, patch_size: int, stride: int, threshold: float = 0.1
):
    """
    Compute top-left coordinates of valid patches based on the non-NaN ratio in each patch.

    Args:
        mask (xr.DataArray): 2D xarray DataArray (e.g. land cover labels or cloud mask).
        patch_size (int): Size of the square patches.
        stride (int): Stride of the sliding window.
        threshold (float): Minimum fraction of valid (non-NaN) pixels in a patch.

    Returns:
        List of (i, j) tuples (top-left corners of valid patches).
    """
    # Binarize lazily with Dask: 1 = valid pixel, 0 = NaN
    binary_mask = (~da.isnan(mask.data)).astype("uint8")  # type: ignore
    logging.info("Shape of binary mask: %s", binary_mask.shape)

    # Sliding windows over 2D: shape (num_patches_y, num_patches_x, patch_size, patch_size)
    sw = sliding_window_view(binary_mask, (patch_size, patch_size))[::stride, ::stride]

    # Mean valid fraction per patch, then evaluate
    patch_valid_fraction = sw.mean(axis=(-1, -2)).compute()

    # Patch indices whose valid fraction meets the threshold
    valid_coords = np.argwhere(patch_valid_fraction >= threshold)

    # Map patch indices back to full-image pixel coordinates
    return [(i * stride, j * stride) for i, j in valid_coords]
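The sliding-window logic in `compute_valid_patch_indices` can be checked on a tiny in-memory mask. This sketch mirrors it with plain NumPy (no Dask or xarray) on a hypothetical 4×4 mask whose top-left quadrant is all NaN:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def valid_patch_indices(mask, patch_size, stride, threshold=0.1):
    # 1 where the pixel is valid, 0 where it is NaN
    binary = (~np.isnan(mask)).astype("uint8")
    # All patch_size x patch_size windows, subsampled by the stride
    windows = sliding_window_view(binary, (patch_size, patch_size))[::stride, ::stride]
    # Fraction of valid pixels per window
    frac = windows.mean(axis=(-1, -2))
    # Map window indices back to top-left pixel coordinates
    return [(i * stride, j * stride) for i, j in np.argwhere(frac >= threshold)]

# 4x4 mask: the top-left 2x2 block is entirely NaN
mask = np.ones((4, 4))
mask[:2, :2] = np.nan
print(valid_patch_indices(mask, patch_size=2, stride=2, threshold=0.5))
# [(0, 2), (2, 0), (2, 2)] -- the all-NaN patch at (0, 0) is rejected
```

With `threshold=0.5`, the patch at (0, 0) has a valid fraction of 0.0 and is dropped, while the other three (fully valid) patches survive.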

jobs.py

from dagster import job
from .ops import fetch_s2_stack, fetch_worldcover_stack, save_to_zarr, train_unet


@job
def s2_worldcover_landcover_classification_pipeline():
    features, proj, mask = fetch_s2_stack()  # pylint:disable=E1120
    labels = fetch_worldcover_stack(proj, mask)  # pylint:disable=E1120
    zarr_dir = save_to_zarr(features, labels)  # pylint:disable=E1120
    train_unet(zarr_dir)  # pylint:disable=E1120

definitions.py

from dagster import Definitions, load_assets_from_modules

from esa_worldcover_classification import assets  # noqa: TID252
from esa_worldcover_classification.jobs import (
    s2_worldcover_landcover_classification_pipeline,
)

all_assets = load_assets_from_modules([assets])

defs = Definitions(
    assets=all_assets,
    jobs=[s2_worldcover_landcover_classification_pipeline],
)
13)