Day4 - ML Pipelines#
Pipeline#
It is a declarative description of a workflow or process, typically in code or configuration, that specifies:
- What tasks should be run
- How they are connected or ordered
- What inputs and outputs are involved
- Any conditions, retries, or resources needed
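For illustration only (not tied to any particular orchestrator's API), such a declarative description can be sketched as plain Python data, plus a small function that derives the execution order from the dependencies; the task names and file names here are made-up examples:

```python
# A hypothetical pipeline description: tasks, their inputs/outputs,
# retry settings, and how they are connected. Real orchestrators
# (Dagster, Airflow, Prefect, ...) offer richer versions of this idea.
pipeline = {
    "tasks": {
        "extract": {"outputs": ["raw.csv"], "retries": 3},
        "transform": {"inputs": ["raw.csv"], "outputs": ["clean.csv"], "retries": 1},
        "train": {"inputs": ["clean.csv"], "outputs": ["model.bin"], "retries": 0},
    },
    # Edges encode "how tasks are connected or ordered"
    "dependencies": {"transform": ["extract"], "train": ["transform"]},
}


def run_order(spec):
    """Topologically sort tasks so each runs after its dependencies."""
    deps = {t: set(spec["dependencies"].get(t, ())) for t in spec["tasks"]}
    order = []
    while deps:
        ready = [t for t, d in deps.items() if not d]
        if not ready:
            raise ValueError("cycle in pipeline")
        for t in sorted(ready):
            order.append(t)
            del deps[t]
        for d in deps.values():
            d.difference_update(ready)
    return order
```

Running `run_order(pipeline)` yields `["extract", "transform", "train"]` — the orchestrator's job is then to execute that order while honoring retries and resources.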
Orchestrator#
An orchestrator is a system or tool that manages and coordinates the execution of complex workflows or tasks, especially when those tasks involve multiple steps, dependencies, and resources.
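To make that concrete, here is a toy orchestrator (purely illustrative; real tools add scheduling, persistence, observability, and distributed execution) that runs callables in dependency order and retries failures:

```python
import time


def orchestrate(tasks, dependencies, max_retries=2, delay=0.0):
    """Run `tasks` (name -> callable) respecting `dependencies`
    (name -> list of prerequisite names), retrying failed tasks."""
    done, results = set(), {}
    while len(done) < len(tasks):
        progressed = False
        for name, fn in tasks.items():
            if name in done or any(d not in done for d in dependencies.get(name, [])):
                continue  # not ready yet
            for attempt in range(max_retries + 1):
                try:
                    # Feed each task the outputs of its prerequisites
                    results[name] = fn(*[results[d] for d in dependencies.get(name, [])])
                    break
                except Exception:
                    if attempt == max_retries:
                        raise
                    time.sleep(delay)
            done.add(name)
            progressed = True
        if not progressed:
            raise ValueError("unsatisfiable dependencies")
    return results


# Example: a two-step flow
results = orchestrate(
    {"get_numbers": lambda: [1, 2, 3], "multiply": lambda xs: [x * 10 for x in xs]},
    {"multiply": ["get_numbers"]},
)
```

Everything an orchestrator like Dagster or Airflow does — queuing, retries, passing outputs between steps — is an industrial-strength version of this loop.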
Comparison Between Popular ML Pipeline Orchestrators#
| Feature | Dagster | Airflow | Kubeflow | Prefect | Metaflow |
|---|---|---|---|---|---|
| Primary Language | Python | Python | Python | Python | Python |
| Designed For | Data & ML pipelines | General workflow orchestration | ML pipelines on Kubernetes | General workflow orchestration | ML pipelines, prototyping to prod |
| ML Native Features | ✅ Strong ML support (IO management, type systems) | ❌ Limited ML support | ✅ Tight integration with TF, K8s | ⚠️ Minimal built-in ML features | ✅ ML-focused abstractions |
| Kubernetes Native | ✅ (via Dagster K8s executor) | ✅ (with Helm, K8sExecutor) | ✅ Fully K8s-native | ✅ Optional | ✅ Optional |
| Local Dev Experience | ✅ Very good (CLI & UI) | ⚠️ Okay but clunky | ❌ Heavyweight (needs K8s) | ✅ Excellent (easy local → cloud) | ✅ Excellent (local-first) |
| UI / Observability | ✅ Excellent UI & asset tracking | ✅ Basic but mature | ✅ Full UI | ✅ Good (flow run UI) | ✅ Great (incl. lineage, retry) |
| Type Safety / IO Mgmt | ✅ Strong typing & asset materialization | ❌ Minimal support | ⚠️ Basic through component specs | ⚠️ Basic typing | ✅ Simple but effective |
| Data Lineage | ✅ First-class asset lineage | ⚠️ Custom plugins needed | ✅ via ML Metadata store | ⚠️ Partial | ✅ Built-in |
| Execution Flexibility | ✅ Local, multiprocess, K8s, etc. | ✅ Executors: Celery, K8s, etc. | ❌ K8s only | ✅ Cloud or local agents | ✅ AWS, K8s, local |
| Extensibility | ✅ Modular, Pythonic design | ✅ Strong DAG customization | ⚠️ Custom container components | ✅ Flows as Python code | ✅ Highly extensible Python code |
| Community / Maturity | ⭐⭐ Growing fast | ⭐⭐⭐ Very mature (but older) | ⭐⭐ Kubernetes/Google centric | ⭐⭐ Fast-growing | ⭐⭐ Strong in enterprise/data science |
| Best Use Case | ML + data pipelines, asset-driven | ETL, batch jobs | Large-scale ML on K8s | Lightweight orchestration | ML workflows from notebook to prod |
🟦 Dagster
Pros: Modern, type-safe, asset-centric, good dev UX, great observability.
Cons: Learning curve around asset concepts.
Best for: ML teams looking for data lineage and strong development ergonomics.
🟩 Airflow
Pros: Battle-tested, huge ecosystem, flexible.
Cons: DSL is clunky for ML; weak typing; hard to trace ML artifacts.
Best for: Traditional ETL and teams with existing Airflow setups.
🟥 Kubeflow
Pros: Cloud-native, scalable, integrates well with K8s and TensorFlow ecosystem.
Cons: Complex setup, Kubernetes-only, poor local dev UX.
Best for: Teams deploying ML at scale on Kubernetes.
🟨 Prefect
Pros: Simple, Pythonic, cloud-native agents, strong developer experience.
Cons: Not ML-specific; less focus on artifact tracking.
Best for: Lightweight workflows, dataops, hybrid cloud/local orchestration.
🟪 Metaflow
Pros: Very ML-friendly, notebook integration, supports branching, versioning, retries.
Cons: Less customizable for general workflows.
Best for: ML teams needing reproducibility from notebook → prod.
Dagster#

An open-source data orchestrator for ML, analytics, and ETL (Extract, Transform, Load). It enables you to build, run, and monitor complex data pipelines. Dagster offers:
- Declarative pipeline definitions (data dependencies and configuration).
- Type-safe operations.
- Native support for assets, schedules, and sensors.
- Integration with popular data tools (e.g., dbt, Spark, MLflow).
Core Concepts#
| Concept | Description |
|---|---|
| Op | A function that performs a unit of work |
| Job | A directed graph of ops |
| Asset | A first-class, versioned data product |
| Graph | A reusable composition of ops |
| Resource | External dependency like S3, DB, API |
| Schedule / Sensor | Triggers jobs by time/event |
Getting Started#
Install Dagster:

```shell
pip install dagster dagit
```

Initialize a new Dagster project:

```shell
dagster project scaffold --name dagster_tutorial
cd dagster_tutorial
```

Run the Dagster development server:

```shell
dagster dev
```

Open the Dagit UI in your browser at http://localhost:3000.

Create a new file `ops.py` in the subdirectory `dagster_tutorial` and add the following code:

```python
from dagster import op


@op
def get_numbers():
    return [1, 2, 3]


@op
def multiply(numbers):
    return [x * 10 for x in numbers]
```

Create a new file `jobs.py` in the subdirectory `dagster_tutorial` and add the following code:

```python
from dagster import job

from .ops import get_numbers, multiply


@job
def process_job():
    multiply(get_numbers())
```

In the `definitions.py` file, import the job and add it to the `Definitions` object:

```python
from dagster import Definitions, load_assets_from_modules

from dagster_tutorial import assets  # noqa: TID252
from dagster_tutorial.jobs import process_job

all_assets = load_assets_from_modules([assets])

defs = Definitions(
    assets=all_assets,
    jobs=[process_job],
)
```

Modify the `multiply` function in `ops.py` to accept runtime config:

```python
from typing import List

from dagster import Config, op


class MultiplyConfig(Config):
    factor: int


@op
def multiply(config: MultiplyConfig, numbers: List[int]):
    return [x * config.factor for x in numbers]
```

In the launchpad, you can now run the `process_job` with a configuration:

```yaml
ops:
  multiply:
    config:
      factor: 10
```

You can enable logging in your Dagster project by adding this anywhere you want to log:

```python
from dagster import get_dagster_logger

logger = get_dagster_logger()
logger.info("This is an info message")
```

You can also use assets to define, persist, and version your data products. For example, create a new file `assets.py` in the subdirectory `dagster_tutorial` and add the following code:

```python
from dagster import asset


@asset
def raw_data():
    return [1, 2, 3, 4]


@asset
def squared_data(raw_data):
    return [x**2 for x in raw_data]
```

You can add a schedule to run a job at a specific time. To do that, add the following code to the `definitions.py` file:

```python
from dagster import ScheduleDefinition

hourly_schedule = ScheduleDefinition(
    job=process_job,
    cron_schedule="0 * * * *",  # every hour
)

defs = Definitions(
    assets=all_assets,
    jobs=[process_job],
    schedules=[hourly_schedule],
)
```
MLflow#

Open-source platform for managing the ML lifecycle, including experimentation, reproducibility, and deployment.
It provides a central repository for tracking experiments, packaging code into reproducible runs, and sharing and deploying models.
It has four main components:
- Tracking: Log and query experiments.
- Projects: Package code in a reusable and reproducible way.
- Models: Manage and deploy models from various ML libraries.
- Registry: Store and manage models in a central repository.
Dagster can be used to orchestrate ML workflows and integrate with MLflow for tracking experiments and managing models.
You can use Dagster to define and run ML pipelines, and use MLflow to log and track experiments, models, and artifacts.
Any of Dagster's building blocks — `@op`, `@job`, `@asset`, `@schedule`, and `@sensor` — can wrap MLflow calls: inside them, use MLflow's Python API to log and track experiments, models, and artifacts.
To test MLflow access, you can run the following Python code:

```python
import os

from dotenv import load_dotenv
import mlflow

load_dotenv("../../.env")

MLFLOW_SERVER_URL = os.getenv("MLFLOW_SERVER_URL", "http://localhost:5000")
MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME")
MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD")
MY_PREFIX = "mohanad-experiment"

# Only set the credentials if they are present in the environment
if MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD:
    os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
    os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD

mlflow.set_tracking_uri(MLFLOW_SERVER_URL)
mlflow.set_experiment(f"/{MY_PREFIX}/classification")

with mlflow.start_run():
    mlflow.log_metric("metric1", 1.0)
```
Dagster + Xarray + Dask to Train ERA5 Forecasting Model#
You can run the advanced Dagster project era5_forecast, which integrates xarray and Dask with Dagster.
Instructions#
1- Create the Dagster project:

```shell
dagster project scaffold -n era5_forecast
cd era5_forecast
```

2- Assuming the dependencies are installed, run the Dagster server:

```shell
dagster dev
```

3- Open the Dagit UI in your browser at http://localhost:3000.
4- On the Jobs tab, select the `training_pipeline` job and click the Launchpad button to run it.
5- Open the Dask dashboard in your browser at http://localhost:8787/status.
ops.py

```python
import os

import dask.array
import numpy as np
import xarray as xr
from dagster import Out, get_dagster_logger, op
from xgboost.dask import DaskDMatrix, train

from .resources import DaskResource

logger = get_dagster_logger()

# ----------- CONFIG -----------
ZARR_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
LAGS = list(range(1, 7))  # 6 days of autoregressive lag features


def create_lagged_features(da, lags):
    lagged = [da.shift(time=lag).rename(f"lag_{lag}") for lag in lags]
    return xr.merge(lagged + [da.rename("target")])


# ----------- OPS -----------
@op
def load_and_preprocess_data():
    ds = xr.open_zarr(
        ZARR_URL,
        chunks={"time": 1},  # type: ignore
        storage_options={"token": "anon"},
    )
    temp = ds["2m_temperature"].sel(
        time=slice("2018-01-01", "2018-01-31"),
        # latitude=slice(40, 50),
        # longitude=slice(0, 10),
    )
    daily = temp.resample(time="1D").mean()
    daily = daily.persist()
    return daily


@op(out={"features": Out(), "labels": Out()})
def generate_training_data(daily):
    ds_lagged = create_lagged_features(daily, lags=LAGS)
    X = xr.concat([ds_lagged[f"lag_{lag}"] for lag in LAGS], dim="feature")
    X = (
        X.stack(sample=("time", "latitude", "longitude"))
        .transpose("sample", "feature")
        .data
    )
    y = ds_lagged["target"].chunk({"time": -1, "latitude": 10, "longitude": 10})
    y = y.stack(sample=("time", "latitude", "longitude")).data
    return X, y


@op
def train_model(my_dask_resource: DaskResource, X, y):
    client = my_dask_resource.make_dask_cluster()
    logger.info("Dask dashboard link %s", client.dashboard_link)

    # Subsample before rechunking to keep memory usage manageable
    sample_frac = 0.05
    logger.info(f"Sampling {sample_frac * 100:.0f}% of data before rechunking")

    n_samples = X.shape[0]
    sample_size = int(n_samples * sample_frac)
    random_indices = np.random.permutation(n_samples)[:sample_size]

    # Turn indices into a Dask array for indexing
    random_indices = dask.array.from_array(random_indices, chunks=(sample_size,))

    # Sample X and y using Dask indexing
    X = X[random_indices]
    y = y[random_indices]

    logger.info("X.shape: %s", X.shape)
    logger.info("y.shape: %s", y.shape)

    logger.info("Splitting sampled data into train/val sets using Dask...")
    val_frac = 0.2
    val_size = int(sample_size * val_frac)
    train_size = sample_size - val_size

    X_train = X[:train_size]
    y_train = y[:train_size]
    X_val = X[train_size:]
    y_val = y[train_size:]

    chunk_size = 32
    X_train = X_train.rechunk((chunk_size, -1))
    y_train = y_train.rechunk((chunk_size,))
    X_val = X_val.rechunk((chunk_size, -1))
    y_val = y_val.rechunk((chunk_size,))

    logger.info("Creating DaskDMatrix for training and validation...")
    dtrain = DaskDMatrix(client, X_train, y_train)
    dval = DaskDMatrix(client, X_val, y_val)

    logger.info("Starting training with validation...")
    output = train(
        client,
        {
            "verbosity": 2,
            "tree_method": "hist",
            "objective": "reg:squarederror",
        },
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, "train"), (dval, "validation")],
        early_stopping_rounds=4,
    )
    booster = output["booster"]
    best_iteration = booster.best_iteration
    logger.info(f"Training stopped at iteration: {best_iteration}")
    os.makedirs("models", exist_ok=True)
    booster.save_model("models/xgb_model.json")
    logger.info("Model saved to models/xgb_model.json")

    client.close()
```
resources.py

```python
from dagster import ConfigurableResource
from dask.distributed import Client, LocalCluster


class DaskResource(ConfigurableResource):
    n_workers: int

    def make_dask_cluster(self) -> Client:
        client = Client(LocalCluster(n_workers=self.n_workers))
        return client
```
jobs.py

```python
from dagster import job

from .ops import generate_training_data, load_and_preprocess_data, train_model


@job
def training_pipeline():
    daily = load_and_preprocess_data()
    X, y = generate_training_data(daily)
    train_model(X, y)  # pylint:disable=E1120
```
definitions.py

```python
from dagster import Definitions, load_assets_from_modules

from era5_forecast import assets  # noqa: TID252
from era5_forecast.jobs import training_pipeline
from era5_forecast.resources import DaskResource  # noqa: TID252

all_assets = load_assets_from_modules([assets])

defs = Definitions(
    assets=all_assets,
    jobs=[training_pipeline],
    resources={"my_dask_resource": DaskResource(n_workers=4)},
)
```
Dagster + Stackstac + Dask + MLflow to Train Sentinel-2 Land Cover Classification Model#
You can run the advanced Dagster project esa_worldcover_classification, which integrates stackstac, xarray, and Dask with Dagster and MLflow.
Note
The ESA WorldCover dataset is a global land cover map produced by the European Space Agency, offering 10-meter resolution classification based on Sentinel-1 and Sentinel-2 satellite imagery. Released in 2021, it provides detailed and consistent information on land cover types such as forests, croplands, urban areas, and water bodies. Designed to support environmental monitoring, climate change studies, and sustainable land management, WorldCover is freely accessible and regularly updated, making it a valuable resource for researchers, policymakers, and Earth observation applications worldwide. The ESA WorldCover dataset includes 11 land cover classes, based on the UN FAO Land Cover Classification System (LCCS). Here’s the list of classes:
| Class ID | Land Cover Class |
|---|---|
| 10 | Tree cover |
| 20 | Shrubland |
| 30 | Grassland |
| 40 | Cropland |
| 50 | Built-up |
| 60 | Bare / sparse vegetation |
| 70 | Snow and ice |
| 80 | Permanent water bodies |
| 90 | Herbaceous wetland |
| 95 | Mangroves |
| 100 | Moss and lichen |
Instructions#
1- Run the MLflow server:

```shell
mlflow server \
  --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root ./mlruns \
  --host 0.0.0.0 \
  --port 8000
```

2- Create the Dagster project:

```shell
dagster project scaffold -n esa_worldcover_classification
cd esa_worldcover_classification
```

3- Assuming the dependencies are installed, run the Dagster server:

```shell
dagster dev
```

4- Open the Dagit UI in your browser at http://localhost:3000.
5- On the Jobs tab, select the esa_worldcover_classification job and click the Launchpad button to run it. Use the following configuration:
Note
Replace the MLflow experiment name with your own, e.g. /mohanad_s2_classification. Make sure the MLflow server URL, username, and password are set in the .env file.
```yaml
ops:
  fetch_s2_stack:
    config:
      bbox: [21.0, 38.0, 21.5, 38.5]
      time_range: "2020-01-01/2020-01-31"
  fetch_worldcover_stack:
    config:
      bbox: [21.0, 38.0, 21.5, 38.5]
      time_range: "2020-01-01/2020-01-31"
  save_to_zarr:
    config:
      zarr_cache_dir: "cache"
  train_unet:
    config:
      patch_size: 64
      stride: 32
      batch_size: 16
      model: "unet"
      epochs: 1
      learning_rate: 0.001
      loss: "cross_entropy"
      num_workers: 4
      in_channels: 5
      out_classes: 4
      mlflow_tracking_uri: null
      mlflow_experiment_name: "s2_classification"
      model_path: "models/unet_model.pth"
      zarr_cache_dir: "cache"
      device: "mps"
```
6- Open the MLflow UI in your browser at http://localhost:8000.
ops.py

```python
import gc
import os
from typing import Optional

from dagster import Config, Out, get_dagster_logger, op
from pystac_client import Client
import planetary_computer
import requests
import stackstac
import xarray as xr
from rasterio.enums import Resampling

from .utils import clean_attrs, clean_coords, map_labels, transform_bbox
from .configurations import (
    AWS_STAC_API,
    CLASS_MAPPING,
    PLANETARY_COMPUTER_STAC_API,
    PLANETARY_COMPUTER_TOKEN_URL,
)
from .train import train_model


logger = get_dagster_logger()


class StackConfig(Config):
    bbox: list[float]
    time_range: str


class ZarrConfig(Config):
    zarr_cache_dir: str


class TrainUNetConfig(Config):
    patch_size: int
    stride: int
    num_workers: int
    batch_size: int
    model: str
    epochs: int
    learning_rate: float
    loss: str
    in_channels: int
    out_classes: int
    mlflow_tracking_uri: Optional[str] = None
    mlflow_experiment_name: str
    model_path: str
    zarr_cache_dir: str
    device: str


@op(out={"pixels_output": Out(), "proj_output": Out(), "valid_mask": Out()})
def fetch_s2_stack(config: StackConfig) -> tuple[xr.DataArray, str, xr.DataArray]:
    catalog = Client.open(AWS_STAC_API)
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=config.bbox,
        datetime=config.time_range,
        query={
            "eo:cloud_cover": {"lt": 10},
            "s2:nodata_pixel_percentage": {"lt": 10},
        },
        max_items=5000,
    )
    all_items = search.item_collection()
    items = []
    granules = []
    for item in all_items:
        if item.properties["s2:granule_id"] not in granules:
            items.append(item)
            granules.append(item.properties["s2:granule_id"])

    logger.info("Found %d Sentinel-2-L2A items", len(items))
    proj = items[0].properties["proj:code"]
    bbox_utm = transform_bbox(config.bbox, items[0].properties["proj:code"])
    assets = ["blue", "green", "red", "nir", "swir16", "scl"]
    ds = stackstac.stack(
        items,
        assets=assets,
        bounds=bbox_utm,  # pyright: ignore
        resolution=20,
        epsg=int(proj.split(":")[-1]),
        dtype="float64",  # pyright: ignore
        rescale=False,
        snap_bounds=True,
        resampling=Resampling.nearest,
        chunksize=(1, 1, 512, 512),
    )

    cloud_values = [0, 1, 2, 3, 8, 9, 10]  # cloud, shadows, cirrus, etc.
    scl_mask = ds.sel(band="scl")
    bands = ds.sel(band=["blue", "green", "red", "nir", "swir16"])
    valid_mask = ~scl_mask.isin(cloud_values)
    bands_masked = bands.where(valid_mask)
    median_ds = bands_masked.groupby("time.month").median("time", skipna=True)
    valid_mask_median = ~xr.ufuncs.isnan(median_ds.isel(band=0))
    median_ds = median_ds.fillna(0)
    logger.info("Median dataset shape: %s", median_ds.shape)
    return (median_ds, proj, valid_mask_median)


@op
def fetch_worldcover_stack(
    config: StackConfig, proj: str, valid_mask: xr.DataArray
) -> xr.DataArray:
    logger.info("Projection: %s", proj)
    catalog = Client.open(PLANETARY_COMPUTER_STAC_API)
    response = requests.get(
        f"{PLANETARY_COMPUTER_TOKEN_URL}/esa-worldcover", timeout=10
    )

    if response.status_code == 200:
        response = response.json()
        token = response["token"]
        _ = {"Authorization": f"Bearer {token}"}
    else:
        print(f"Failed to get token. Status code: {response.status_code}")
        exit()
    search = catalog.search(
        collections=["esa-worldcover"],
        bbox=config.bbox,
        query={
            "start_datetime": {"eq": "2020-01-01T00:00:00Z"},
            "end_datetime": {"eq": "2020-12-31T23:59:59Z"},
        },
    )
    all_items = search.item_collection()
    items = [planetary_computer.sign_item(item) for item in all_items]
    logger.info("Found %d ESA World Cover items", len(items))
    bbox_utm = transform_bbox(config.bbox, proj)
    ds = stackstac.stack(
        items,
        assets=["map"],
        bounds=bbox_utm,  # pyright: ignore
        resolution=20,
        epsg=int(proj.split(":")[-1]),
        dtype="float64",  # pyright: ignore
        rescale=False,
        snap_bounds=True,
        resampling=Resampling.nearest,
        chunksize=(1, 1, 512, 512),
    )
    ds = ds.sel(time=ds.notnull().any(dim=["x", "y", "band"]))
    ds = ds.where(valid_mask)
    ds = ds.fillna(0)
    logger.info("WorldCover dataset shape: %s", ds.shape)
    return ds


@op
def save_to_zarr(
    config: ZarrConfig, features: xr.DataArray, labels: xr.DataArray
) -> str:
    features = features.squeeze()
    labels = labels.squeeze()
    features = features.drop_vars(
        [dim for dim in ["time", "month"] if dim in features.dims]
    )
    labels = labels.drop_vars([dim for dim in ["time", "month"] if dim in labels.dims])
    features, labels = xr.align(features, labels, join="inner")
    labels = map_labels(labels, CLASS_MAPPING)
    img_zarr_path = os.path.join(config.zarr_cache_dir, "img.zarr")
    mask_zarr_path = os.path.join(config.zarr_cache_dir, "mask.zarr")
    if not os.path.exists(img_zarr_path) and not os.path.exists(mask_zarr_path):
        features = features.chunk({"x": 512, "y": 512})
        labels = labels.chunk({"x": 512, "y": 512})
        features.data = features.data.compute_chunk_sizes()
        labels.data = labels.data.compute_chunk_sizes()
        features = clean_attrs(features)
        features = clean_coords(features)
        features.to_zarr(img_zarr_path)

        labels = clean_attrs(labels)
        labels = clean_coords(labels)
        labels.to_zarr(mask_zarr_path)
    del features
    del labels
    gc.collect()
    return config.zarr_cache_dir


@op
def train_unet(config: TrainUNetConfig, zarr_dir: str):  # pylint:disable=W0613
    train_model(config)
```
model.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.segformer import (
    SegformerForSemanticSegmentation,
    SegformerConfig,
)


class UNetBaseLine(nn.Module):
    def __init__(self, in_channels=4, out_classes=11, dropout=0.2):
        super(UNetBaseLine, self).__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.bottleneck = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.dec2 = conv_block(512 + 128, 128)
        self.dec1 = conv_block(128 + 64, 64)

        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))

        d2 = self._upsample_concat(b, e2)
        d2 = self.dec2(d2)

        d1 = self._upsample_concat(d2, e1)
        d1 = self.dec1(d1)

        return self.final(d1)  # (B, 11, H, W)

    def _upsample_concat(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)


class UNet(nn.Module):
    def __init__(self, in_channels=4, out_classes=11, dropout=0.1, init_weights=True):
        super(UNet, self).__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(256, 512)

        # Decoder
        self.dec3 = conv_block(512 + 256, 256)
        self.dec2 = conv_block(256 + 128, 128)
        self.dec1 = conv_block(128 + 64, 64)
        self.dec0 = conv_block(64, 32)

        # Final classifier
        self.final = nn.Conv2d(32, out_classes, kernel_size=1)

        if init_weights:
            self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        e1 = self.enc1(x)  # (B, 64, H, W)
        e2 = self.enc2(self.pool(e1))  # (B, 128, H/2, W/2)
        e3 = self.enc3(self.pool(e2))  # (B, 256, H/4, W/4)
        b = self.bottleneck(self.pool(e3))  # (B, 512, H/8, W/8)

        d3 = self._upsample_concat(b, e3)  # (B, 512+256, H/4, W/4)
        d3 = self.dec3(d3)

        d2 = self._upsample_concat(d3, e2)  # (B, 256+128, H/2, W/2)
        d2 = self.dec2(d2)

        d1 = self._upsample_concat(d2, e1)  # (B, 128+64, H, W)
        d1 = self.dec1(d1)

        d0 = self.dec0(d1)  # (B, 32, H, W)

        return self.final(d0)  # logits: (B, out_classes, H, W)

    def _upsample_concat(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)


class SegformerB0FourBand(nn.Module):
    def __init__(self, in_channels=4, out_classes=11):
        super().__init__()

        # Step 1: Load pretrained model to extract 3-channel conv weights
        pretrained_model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512"
        )
        pretrained_conv = pretrained_model.segformer.encoder.patch_embeddings[0].proj
        pretrained_weights = pretrained_conv.weight  # type: ignore # shape: [32, 3, 7, 7]

        # Step 2: Create modified config
        config = SegformerConfig.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512"
        )
        config.num_channels = in_channels
        config.num_labels = out_classes

        # Step 3: Create your model
        self.model = SegformerForSemanticSegmentation(config)

        # Step 4: Replace first conv with 4-band version and copy weights
        old_conv = self.model.segformer.encoder.patch_embeddings[0].proj
        new_conv = nn.Conv2d(
            4,
            old_conv.out_channels,  # type: ignore
            kernel_size=old_conv.kernel_size,  # type: ignore
            stride=old_conv.stride,  # type: ignore
            padding=old_conv.padding,  # type: ignore
            bias=old_conv.bias is not None,  # type: ignore
        )
        with torch.no_grad():
            new_conv.weight[:, :3] = pretrained_weights  # type: ignore # use original RGB weights
            new_conv.weight[:, 3] = pretrained_weights[:, 0]  # type: ignore # init 4th like Red
            if old_conv.bias is not None:  # type: ignore
                new_conv.bias.copy_(pretrained_conv.bias)  # type: ignore
        self.model.segformer.encoder.patch_embeddings[0].proj = new_conv

        # Step 5: Load other pretrained weights (except the mismatched ones)
        state_dict = pretrained_model.state_dict()
        for key in [
            "segformer.encoder.patch_embeddings.0.proj.weight",
            "decode_head.classifier.weight",
            "decode_head.classifier.bias",
        ]:
            state_dict.pop(key, None)
        self.model.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        logits = self.model(pixel_values=x).logits
        logits = F.interpolate(
            logits, size=x.shape[-2:], mode="bilinear", align_corners=False
        )
        logits = logits.contiguous()
        return logits
```
dataset.py

```python
import os

import torch
from torch.utils.data import Dataset
import xarray as xr

from .utils import compute_valid_patch_indices


class XarrayPatchDataset(Dataset):
    def __init__(self, patch_size: int, stride: int, zarr_cache_dir: str, indices=None):
        """
        Dataset for extracting patches from xarray images, using precomputed valid patch indices.

        Args:
            patch_size (int): Size of square patches.
            stride (int): Stride of sliding window.
            zarr_cache_dir (str): Directory to store the cached data in Zarr format.
            indices (list, optional): Precomputed patch indices; computed from the mask when None.
        """
        self.patch_size = patch_size
        self.stride = stride
        self.zarr_cache_dir = zarr_cache_dir
        self.img_zarr_path = os.path.join(zarr_cache_dir, "img.zarr")
        self.mask_zarr_path = os.path.join(zarr_cache_dir, "mask.zarr")
        self.img = xr.open_zarr(self.img_zarr_path)
        self.img = next(iter(self.img.data_vars.values()))
        self.mask = xr.open_zarr(self.mask_zarr_path)
        self.mask = next(iter(self.mask.data_vars.values()))
        self.mask.data.compute_chunk_sizes()
        if indices is None:
            indices = compute_valid_patch_indices(
                self.mask, patch_size, stride, threshold=0.1
            )
        self.valid_indices = indices

    def __len__(self):
        """Return the number of valid patches."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Get the patch at the given index."""
        # Get the top-left coordinates of the patch
        i, j = self.valid_indices[idx]

        # Use Dask arrays directly for patch extraction
        img_patch = self.img[:, i : i + self.patch_size, j : j + self.patch_size]
        mask_patch = self.mask[i : i + self.patch_size, j : j + self.patch_size]

        # Convert to Dask arrays (lazily loaded)
        img_patch = img_patch.data  # Dask array
        mask_patch = mask_patch.data  # Dask array

        # Compute and convert to PyTorch tensors
        x = torch.tensor(img_patch.compute(), dtype=torch.float32)
        y = torch.tensor(mask_patch.compute(), dtype=torch.long)

        # Scale reflectance values and clip to [0, 1]
        x = x / 10000.0
        x = x.clip(min=0.0, max=1.0)

        return x, y
```
train.py
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import mlflow
from dagster import get_dagster_logger
import xarray as xr
from sklearn.model_selection import train_test_split
from torchmetrics import JaccardIndex
from torchmetrics.segmentation import DiceScore, MeanIoU
from dotenv import load_dotenv
from .loss import FocalLoss
from .model import UNet, UNetBaseLine, SegformerB0FourBand
from .dataset import XarrayPatchDataset
from .utils import compute_valid_patch_indices

load_dotenv("../../.env")

MLFLOW_SERVER_URL = os.getenv("MLFLOW_SERVER_URL")
MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME")
MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
S3_END_POINT = os.getenv("S3_END_POINT")
logger = get_dagster_logger()


def train_model(config):
    os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
    if not config.mlflow_tracking_uri:
        mlflow_tracking_uri = MLFLOW_SERVER_URL
    else:
        mlflow_tracking_uri = config.mlflow_tracking_uri
    if config.model == "unet":
        model = UNet(
            in_channels=config.in_channels,
            out_classes=config.out_classes,
            dropout=0.2,
        )
    elif config.model == "segformer":
        model = SegformerB0FourBand(
            in_channels=config.in_channels,
            out_classes=config.out_classes,
        )
    else:
        model = UNetBaseLine(
            in_channels=config.in_channels, out_classes=config.out_classes
        )
    mask = xr.open_zarr(os.path.join(config.zarr_cache_dir, "mask.zarr"))
    mask = next(iter(mask.data_vars.values()))
    full_indices = compute_valid_patch_indices(
        mask,
        config.patch_size,
        config.stride,
        threshold=0.1,
    )
    # Shuffle + split
    train_indices, val_indices = train_test_split(
        full_indices, test_size=0.2, random_state=42
    )
    train_dataset = XarrayPatchDataset(
        config.patch_size,
        config.stride,
        config.zarr_cache_dir,
        indices=train_indices,
    )
    valid_dataset = XarrayPatchDataset(
        config.patch_size,
        config.stride,
        config.zarr_cache_dir,
        indices=val_indices,
    )
    if config.loss == "focal":
        criterion = FocalLoss(alpha=None, gamma=2.0, ignore_index=0)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    jaccard = JaccardIndex(
        task="multiclass",
        num_classes=config.out_classes,
        average="weighted",
        zero_division=0.0,
        ignore_index=0,
    ).to(config.device)
    dice = DiceScore(
        num_classes=config.out_classes, input_format="index", zero_division=0.0
    ).to(config.device)
    iou = MeanIoU(num_classes=config.out_classes, input_format="index").to(
        config.device
    )
    model.to(config.device)
    model.train()
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        prefetch_factor=2,
    )
    val_loader = DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        prefetch_factor=2,
    )
    n_train = len(train_loader)
    n_val = len(val_loader)
    best_val_loss = float("inf")
    best_model_path = os.path.join(os.path.dirname(config.model_path), "best_model.pth")
    os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
    os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD
    os.environ["AWS_ACCESS_KEY_ID"] = S3_ACCESS_KEY
    os.environ["AWS_SECRET_ACCESS_KEY"] = S3_SECRET_ACCESS_KEY
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = S3_END_POINT
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(config.mlflow_experiment_name)
    with mlflow.start_run():
        mlflow.log_params(
            {
                "model": config.model,
                "num_epochs": config.epochs,
                "learning_rate": config.learning_rate,
                "loss": config.loss,
                "batch_size": config.batch_size,
                "patch_size": config.patch_size,
                "stride": config.stride,
                "num_workers": config.num_workers,
                "in_channels": config.in_channels,
                "out_classes": config.out_classes,
                "zarr_cache_dir": config.zarr_cache_dir,
                "model_path": config.model_path,
                "best_model_path": best_model_path,
            }
        )
        for epoch in range(config.epochs):
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
            total_train_loss = 0
            total_train_dice = 0
            total_train_jaccard = 0
            total_train_iou = 0
            total_val_loss = 0
            total_val_dice = 0
            total_val_jaccard = 0
            total_val_iou = 0
            model.train()
            for _, (X_batch, y_batch) in enumerate(progress_bar):
                X_batch = X_batch.to(config.device)
                y_batch = y_batch.to(config.device)
                out = model(X_batch)
                loss = criterion(out, y_batch)
                preds = torch.argmax(out, dim=1)
                dice_coef = dice(preds, y_batch)
                jaccard_index = jaccard(preds, y_batch)
                iou_index = iou(preds, y_batch)
                optimizer.zero_grad()
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()
                total_train_dice += dice_coef.item()
                total_train_jaccard += jaccard_index.item()
                total_train_iou += iou_index.item()

            mlflow.log_metric("loss", total_train_loss / n_train, step=epoch)
            mlflow.log_metric("dice", total_train_dice / n_train, step=epoch)
            mlflow.log_metric("jaccard", total_train_jaccard / n_train, step=epoch)
            mlflow.log_metric("iou", total_train_iou / n_train, step=epoch)
            model.eval()
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(config.device)
                    y_batch = y_batch.to(config.device)
                    out = model(X_batch)
                    loss = criterion(out, y_batch)
                    preds = torch.argmax(out, dim=1)
                    total_val_loss += loss.item()
                    total_val_dice += dice(preds, y_batch).item()
                    total_val_jaccard += jaccard(preds, y_batch).item()
                    total_val_iou += iou(preds, y_batch).item()
            avg_val_loss = total_val_loss / n_val
            mlflow.log_metric("val_loss", avg_val_loss, step=epoch)
            mlflow.log_metric("val_dice", total_val_dice / n_val, step=epoch)
            mlflow.log_metric("val_jaccard", total_val_jaccard / n_val, step=epoch)
            mlflow.log_metric("val_iou", total_val_iou / n_val, step=epoch)
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)
                logger.info(f"Saved new best model at epoch {epoch+1} with val_loss={avg_val_loss:.4f}")
        torch.save(model.state_dict(), config.model_path)
        input_example = X_batch[0].unsqueeze(0).cpu().numpy()
        mlflow.pytorch.log_model(model, artifact_path="last_model", input_example=input_example)
        model.load_state_dict(torch.load(best_model_path))
        mlflow.pytorch.log_model(model, artifact_path="best_model", input_example=input_example)
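`train_model` reads many attributes off a single `config` object, but the fields are never declared in one place. A minimal sketch of what such a config could look like as a dataclass — every field name comes from the code above, while all default values are illustrative assumptions, not the project's settings:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Field names mirror the config attributes used in train_model;
    # the defaults here are placeholders, not the project's values.
    model: str = "unet"            # "unet" | "segformer" | anything else -> baseline
    loss: str = "focal"            # "focal" | anything else -> cross-entropy
    epochs: int = 10
    learning_rate: float = 1e-3
    batch_size: int = 8
    patch_size: int = 256
    stride: int = 128
    num_workers: int = 4
    in_channels: int = 4
    out_classes: int = 11
    device: str = "cuda"
    zarr_cache_dir: str = "data/zarr_cache"
    model_path: str = "models/last_model.pth"
    mlflow_tracking_uri: str = ""  # empty -> falls back to MLFLOW_SERVER_URL
    mlflow_experiment_name: str = "esa-worldcover"


cfg = TrainConfig(model="segformer", epochs=25)
print(cfg.model, cfg.epochs)  # segformer 25
```

Collecting these knobs in one typed object also makes them trivial to log with `mlflow.log_params`, as the training loop does.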
utils.py
import logging
import xarray as xr
import numpy as np
from pyproj import Transformer
import dask.array as da
from numpy.lib.stride_tricks import sliding_window_view


def transform_bbox(bbox_wgs84, dest_crs):
    """Reproject a [minx, miny, maxx, maxy] WGS84 bounding box into `dest_crs`."""
    transformer = Transformer.from_crs("epsg:4326", dest_crs, always_xy=True)
    minx, miny = transformer.transform(bbox_wgs84[0], bbox_wgs84[1])
    maxx, maxy = transformer.transform(bbox_wgs84[2], bbox_wgs84[3])
    bbox_utm = [minx, miny, maxx, maxy]
    return bbox_utm


def map_labels(labels: xr.DataArray, mapping: dict[int, int]) -> xr.DataArray:
    """Remap label values according to `mapping`; unmapped values become 0."""
    data = labels.values
    mapped = np.full_like(data, fill_value=0, dtype=np.int32)
    for original, new in mapping.items():
        mapped[data == original] = new
    return xr.DataArray(mapped, coords=labels.coords, dims=labels.dims)


def clean_attrs(da_or_ds):
    # Only keep simple, JSON-serializable entries in attrs
    da_or_ds.attrs = {
        k: v
        for k, v in da_or_ds.attrs.items()
        if isinstance(v, (str, int, float, list, dict, bool, type(None)))
    }
    return da_or_ds


def clean_coords(ds: xr.Dataset, keep: list[str] | None = None) -> xr.Dataset:
    """
    Remove non-dimension coordinates from an xarray Dataset,
    except those explicitly listed in `keep`.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset to clean.
    keep : list of str, optional
        Coordinate names to keep, even if they are non-dimension coords.

    Returns
    -------
    xr.Dataset
        Cleaned dataset with only dimension coordinates and selected metadata.
    """
    keep = keep or []

    # Get all dimension coordinates
    dim_coords = list(ds.dims)

    # Identify coordinates to drop (non-dim, not explicitly kept)
    drop_coords = [c for c in ds.coords if c not in dim_coords and c not in keep]

    # Drop extra coordinates
    ds = ds.drop_vars(drop_coords)

    return ds


def compute_valid_patch_indices(
    mask: xr.DataArray, patch_size: int, stride: int, threshold: float = 0.1
):
    """
    Compute top-left coordinates of valid patches based on the non-NaN ratio in each patch.

    Args:
        mask (xr.DataArray): 2D xarray DataArray (e.g. land cover labels or cloud mask).
        patch_size (int): Size of square patches.
        stride (int): Stride of the sliding window.
        threshold (float): Minimum fraction of valid (non-NaN) pixels in a patch.

    Returns:
        List of (i, j) tuples (top-left corners of valid patches).
    """
    # Binarize lazily as a Dask array: 1 = valid pixel, 0 = NaN
    binary_mask = (~da.isnan(mask.data)).astype("uint8")  # type: ignore
    logging.info("Shape of binary mask: %s", binary_mask.shape)

    # NumPy's sliding_window_view dispatches to Dask via the array-function
    # protocol, so this stays lazy. Window shape:
    # (num_patches_y, num_patches_x, patch_size, patch_size)
    sw = sliding_window_view(binary_mask, (patch_size, patch_size))[::stride, ::stride]

    # Mean valid fraction per patch (still lazy)
    patch_valid_fraction = sw.mean(axis=(-1, -2))

    # Trigger the actual computation
    patch_valid_fraction = patch_valid_fraction.compute()

    # Keep patches whose valid fraction meets the threshold
    valid_coords = np.argwhere(patch_valid_fraction >= threshold)

    # Map patch indices back to full-image pixel coordinates
    indices = [(i * stride, j * stride) for i, j in valid_coords]

    return indices
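To sanity-check the sliding-window logic in `compute_valid_patch_indices`, here is a standalone NumPy-only re-implementation of the same steps on a toy mask — no Dask or xarray, purely for illustration:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def valid_patch_indices(mask, patch_size, stride, threshold=0.1):
    # Same steps as compute_valid_patch_indices, on a plain NumPy array:
    # binarize (1 = valid, 0 = NaN), window, average, threshold, rescale.
    binary = (~np.isnan(mask)).astype("uint8")
    sw = sliding_window_view(binary, (patch_size, patch_size))[::stride, ::stride]
    frac = sw.mean(axis=(-1, -2))
    coords = np.argwhere(frac >= threshold)
    return [(i * stride, j * stride) for i, j in coords]


mask = np.zeros((8, 8))
mask[:, 4:] = np.nan  # right half of the image is invalid

# Only the two patches covering the left half pass a 50% validity threshold
print(valid_patch_indices(mask, patch_size=4, stride=4, threshold=0.5))
# → [(0, 0), (4, 0)]
```

The multiplication by `stride` at the end is what converts window-grid indices back into pixel coordinates in the full image.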
jobs.py
from dagster import job
from .ops import fetch_s2_stack, fetch_worldcover_stack, save_to_zarr, train_unet


@job
def s2_worldcover_landcover_classification_pipeline():
    features, proj, mask = fetch_s2_stack()  # pylint:disable=E1120
    labels = fetch_worldcover_stack(proj, mask)  # pylint:disable=E1120
    zarr_dir = save_to_zarr(features, labels)  # pylint:disable=E1120
    train_unet(zarr_dir)  # pylint:disable=E1120
definitions.py
from dagster import Definitions, load_assets_from_modules

from esa_worldcover_classification import assets  # noqa: TID252
from esa_worldcover_classification.jobs import (
    s2_worldcover_landcover_classification_pipeline,
)

all_assets = load_assets_from_modules([assets])

defs = Definitions(
    assets=all_assets,
    jobs=[s2_worldcover_landcover_classification_pipeline],
)