
Commit 7db84e4

shoyer and Xarray-Beam authors authored and committed
Allow for passing pipelines to xbeam.Dataset constructors.
Associating a beam.Pipeline with an xbeam.Dataset means that a pipeline doesn't need to be applied later (e.g., to the result of `to_zarr`). This is both a little cleaner and also potentially a significant optimization, because it means that Beam understands that it can reuse a ptransform rather than recomputing it.

This includes a new `_LazyPCollection` class to ensure that our optimizations for transforms applied directly after xbeam.DatasetToChunks still work.

PiperOrigin-RevId: 824281688
1 parent afd3f80 commit 7db84e4
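For context, a minimal before/after sketch of the usage change this commit enables, adapted from the docstring example updated below (the Zarr paths are placeholders):

    import apache_beam as beam
    import xarray_beam as xbeam

    input_path, output_path = 'in.zarr', 'out.zarr'  # placeholder paths

    # Before: build a PTransform, then apply the pipeline to it afterwards.
    with beam.Pipeline() as p:
      p | (
          xbeam.Dataset.from_zarr(input_path)
          .map_blocks(lambda x: x.median('time'))
          .to_zarr(output_path)
      )

    # After: attach the pipeline at construction time; no trailing `p |` is
    # needed, and Beam can reuse the resulting transform instead of
    # recomputing it.
    with beam.Pipeline() as p:
      (
          xbeam.Dataset.from_zarr(input_path, pipeline=p)
          .map_blocks(lambda x: x.median('time'))
          .to_zarr(output_path)
      )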

File tree

4 files changed: +239 -88 lines


docs/high-level.ipynb

Lines changed: 20 additions & 9 deletions
@@ -176,9 +176,9 @@
     },
     "cell_type": "code",
     "source": [
-      "with beam.Pipeline() as p:\n",
-      "  p | (\n",
-      "      xbeam.Dataset.from_zarr('example_data.zarr')\n",
+      "with beam.Pipeline() as pipeline:\n",
+      "  (\n",
+      "      xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
       "      .rechunk({'time': -1, ...: '100 MB'})\n",
       "      .map_blocks(lambda ds: ds.groupby('time.month').mean())\n",
       "      .rechunk('10 MB')  # ensure a reasonable min chunk-size for Zarr\n",
@@ -206,9 +206,9 @@
     },
     "cell_type": "code",
     "source": [
-      "with beam.Pipeline() as p:\n",
-      "  p | (\n",
-      "      xbeam.Dataset.from_zarr('example_data.zarr')\n",
+      "with beam.Pipeline() as pipeline:\n",
+      "  (\n",
+      "      xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
       "      .rechunk({'time': '30MB', 'latitude': -1, 'longitude': -1})\n",
       "      .map_blocks(lambda ds: ds.coarsen(latitude=2, longitude=2).mean())\n",
       "      .to_zarr('example_regrid.zarr')\n",
@@ -254,7 +254,7 @@
     },
     "cell_type": "markdown",
     "source": [
-      "You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template \u003cxarray_beam.Dataset.template\u003e` or produced by {py:func}`~xarray_beam.make_template`:"
+      "You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template <xarray_beam.Dataset.template>` or produced by {py:func}`~xarray_beam.make_template`:"
     ]
   },
   {
@@ -270,7 +270,18 @@
       ")"
     ],
     "outputs": [],
-    "execution_count": 6
+    "execution_count": 5
+  },
+  {
+    "metadata": {
+      "id": "7l8Cw8xTURea"
+    },
+    "cell_type": "markdown",
+    "source": [
+      "```{tip}\n",
+      "Notice that supplying `pipeline` to {py:func}`~xarray_beam.Dataset.from_zarr` is _optional_. You'll need to eventually apply a Beam pipeline to a `PTransform` produced by Xarray-Beam to compute it, but it can be convenient to omit when building pipelines interactively.\n",
+      "```"
+    ]
   },
   {
     "metadata": {
@@ -301,7 +312,7 @@
     "all_times = pd.date_range('2025-01-01', freq='1D', periods=365)\n",
     "source_dataset = xarray.open_zarr('example_data.zarr', chunks=None)\n",
     "\n",
-    "def load_chunk(time: pd.Timestamp) -\u003e tuple[xbeam.Key, xarray.Dataset]:\n",
+    "def load_chunk(time: pd.Timestamp) -> tuple[xbeam.Key, xarray.Dataset]:\n",
     "  key = xbeam.Key({'time': (time - all_times[0]).days})\n",
     "  dataset = source_dataset.sel(time=[time])\n",
     "  return key, dataset\n",

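The new tip cell above notes that `pipeline` is optional. A minimal sketch of the deferred style it describes, assuming `example_data.zarr` exists and writing to a hypothetical `example_output.zarr`:

    import apache_beam as beam
    import xarray_beam as xbeam

    # Build the transform interactively, with no pipeline attached yet.
    transform = (
        xbeam.Dataset.from_zarr('example_data.zarr')
        .rechunk('10 MB')
        .to_zarr('example_output.zarr')
    )

    # Apply a pipeline later to actually run the computation.
    with beam.Pipeline() as pipeline:
      pipeline | transform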
xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -55,4 +55,4 @@
     DatasetToZarr as DatasetToZarr,
 )
 
-__version__ = '0.11.2'  # automatically synchronized to pyproject.toml
+__version__ = '0.11.3'  # automatically synchronized to pyproject.toml

xarray_beam/_src/dataset.py

Lines changed: 88 additions & 34 deletions
@@ -17,14 +17,13 @@
 
   import xarray_beam as xbeam
 
-  transform = (
-      xbeam.Dataset.from_zarr(input_path)
-      .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
-      .map_blocks(lambda x: x.median('time'))
-      .to_zarr(output_path)
-  )
   with beam.Pipeline() as p:
-    p | transform
+    (
+        xbeam.Dataset.from_zarr(input_path, pipeline=p)
+        .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
+        .map_blocks(lambda x: x.median('time'))
+        .to_zarr(output_path)
+    )
 """
 from __future__ import annotations
 

@@ -371,21 +370,21 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
     chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
 
     label = _get_label(method_name)
-    if isinstance(self.ptransform, core.DatasetToChunks):
+
+    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(ptransform, core.DatasetToChunks):
       # Some transformations (e.g., indexing) can be applied much less
       # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
       # to preserve this option for downstream transformations if possible.
-      dataset = func(self.ptransform.dataset)
+      dataset = func(ptransform.dataset)
       ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
-      ptransform.label = _concat_labels(self.ptransform.label, label)
+      ptransform.label = _concat_labels(ptransform.label, label)
+      if pipeline is not None:
+        ptransform = _LazyPCollection(pipeline, ptransform)
     else:
       ptransform = self.ptransform | label >> beam.MapTuple(
           functools.partial(
-              _apply_to_each_chunk,
-              func,
-              method_name,
-              self.chunks,
-              chunks
+              _apply_to_each_chunk, func, method_name, self.chunks, chunks
           )
       )
     return Dataset(template, chunks, self.split_vars, ptransform)
@@ -405,6 +404,43 @@ def apply(self, name: str) -> str:
 _get_label = _CountNamer().apply
 
 
+@dataclasses.dataclass(frozen=True)
+class _LazyPCollection:
+  """Pipeline and PTransform that have not yet been combined into a PCollection."""
+  # Beam does not provide a public API for manipulating Pipeline objects, so
+  # instead of applying pipelines eagerly, we store them in this wrapper. This
+  # allows for performance optimizations specialized to Xarray-Beam PTransforms
+  # (in particular for DatasetToChunks), even in the case where a pipeline is
+  # supplied.
+  pipeline: beam.Pipeline
+  ptransform: beam.PTransform
+
+  # Cache the evaluated PCollection, so we apply the transform to the pipeline
+  # at most once. Otherwise, reapplying the same transform results in reused
+  # labels in the same pipeline, which is an error in Beam.
+  @functools.cached_property
+  def evaluated(self) -> beam.PCollection:
+    return self.pipeline | self.ptransform
+
+
+def _split_lazy_pcollection(
+    value: beam.PTransform | beam.PCollection | _LazyPCollection,
+) -> tuple[beam.Pipeline | None, beam.PTransform | beam.PCollection]:
+  if isinstance(value, _LazyPCollection):
+    return value.pipeline, value.ptransform
+  else:
+    return None, value
+
+
+def _as_eager_pcollection_or_ptransform(
+    value: beam.PTransform | beam.PCollection | _LazyPCollection,
+) -> beam.PTransform | beam.PCollection:
+  if isinstance(value, _LazyPCollection):
+    return value.evaluated
+  else:
+    return value
+
+
 @core.export
 @dataclasses.dataclass
 class Dataset:
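The cached property above exists because Beam refuses to apply two transforms with the same label to one pipeline. A standalone sketch of that failure mode, using plain Beam rather than Xarray-Beam:

    import apache_beam as beam

    pipeline = beam.Pipeline()
    create = beam.Create([1, 2, 3])

    pcoll = pipeline | 'read' >> create  # first application succeeds
    # A second `pipeline | 'read' >> create` raises RuntimeError, because the
    # label 'read' already exists in the pipeline. Caching `evaluated` ensures
    # the wrapped transform is applied to the pipeline at most once.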
@@ -415,7 +451,7 @@ def __init__(
       template: xarray.Dataset,
       chunks: Mapping[str, int],
       split_vars: bool,
-      ptransform: beam.PTransform,
+      ptransform: beam.PTransform | beam.PCollection | _LazyPCollection,
   ):
     """Low level interface for creating a new Dataset, without validation.
 
@@ -429,7 +465,7 @@ def __init__(
         use :py:func:`xarray_beam.normalize_chunks`.
       split_vars: whether variables are split between separate elements in the
         ptransform, or all stored in the same element.
-      ptransform: Beam PTransform of ``(xbeam.Key, xarray.Dataset)`` tuples with
+      ptransform: Beam collection of ``(xbeam.Key, xarray.Dataset)`` tuples with
         this dataset's data.
     """
     self._template = template
@@ -453,9 +489,9 @@ def split_vars(self) -> bool:
     return self._split_vars
 
   @property
-  def ptransform(self) -> beam.PTransform:
+  def ptransform(self) -> beam.PTransform | beam.PCollection:
     """Beam PTransform of (xbeam.Key, xarray.Dataset) with this dataset's data."""
-    return self._ptransform
+    return _as_eager_pcollection_or_ptransform(self._ptransform)
 
   @property
   def sizes(self) -> Mapping[str, int]:
@@ -507,7 +543,7 @@ def __repr__(self):
     plural = 's' if chunk_count != 1 else ''
     return (
         '<xarray_beam.Dataset>\n'
-        f'PTransform: {self.ptransform}\n'
+        f'PTransform: {self._ptransform}\n'
         f'Chunks: {chunk_size} ({chunks_str})\n'
         f'Template: {total_size} ({chunk_count} chunk{plural})\n'
         + textwrap.indent('\n'.join(base.split('\n')[1:]), ' ' * 4)
@@ -516,7 +552,7 @@ def __repr__(self):
   @classmethod
   def from_ptransform(
       cls,
-      ptransform: beam.PTransform,
+      ptransform: beam.PTransform | beam.PCollection,
       *,
       template: xarray.Dataset,
       chunks: Mapping[str | types.EllipsisType, int],
@@ -535,9 +571,9 @@ def from_ptransform(
       outputs are valid.
 
     Args:
-      ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)`` pairs.
-        You only need to set ``offsets`` on these keys, ``vars`` will be
-        automatically set based on the dataset if ``split_vars`` is True.
+      ptransform: A Beam collection of ``(Key, xarray.Dataset)`` pairs. You only
+        need to set ``offsets`` on these keys, ``vars`` will be automatically
+        set based on the dataset if ``split_vars`` is True.
       template: An ``xarray.Dataset`` object representing the schema
         (coordinates, dimensions, data variables, and attributes) of the full
         dataset, as produced by :py:func:`xarray_beam.make_template`, with data
@@ -577,6 +613,7 @@ def from_xarray(
       *,
       split_vars: bool = False,
       previous_chunks: Mapping[str, int] | None = None,
+      pipeline: beam.Pipeline | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from an xarray.Dataset.
 
@@ -588,13 +625,17 @@ def from_xarray(
         ptransform, or all stored in the same element.
       previous_chunks: chunks hint used for parsing string values in ``chunks``
         with ``normalize_chunks()``.
+      pipeline: Beam pipeline to use for this dataset. If not provided, you
+        will need to apply a pipeline later to compute this dataset.
     """
     template = zarr.make_template(source)
     if previous_chunks is None:
       previous_chunks = source.sizes
     chunks = normalize_chunks(chunks, template, split_vars, previous_chunks)
     ptransform = core.DatasetToChunks(source, chunks, split_vars)
     ptransform.label = _get_label('from_xarray')
+    if pipeline is not None:
+      ptransform = _LazyPCollection(pipeline, ptransform)
     return cls(template, dict(chunks), split_vars, ptransform)
 
   @classmethod
@@ -604,6 +645,7 @@ def from_zarr(
       *,
       chunks: UnnormalizedChunks | None = None,
       split_vars: bool = False,
+      pipeline: beam.Pipeline | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Zarr store.
 
@@ -614,6 +656,8 @@ def from_zarr(
         provided, the chunk sizes will be inferred from the Zarr file.
       split_vars: whether variables are split between separate elements in the
         ptransform, or all stored in the same element.
+      pipeline: Beam pipeline to use for this dataset. If not provided, you
+        will need to apply a pipeline later to compute this dataset.
 
     Returns:
       New Dataset created from the Zarr store.
@@ -622,9 +666,14 @@ def from_zarr(
     if chunks is None:
       chunks = previous_chunks
     result = cls.from_xarray(
-        source, chunks, split_vars=split_vars, previous_chunks=previous_chunks
+        source,
+        chunks,
+        split_vars=split_vars,
+        previous_chunks=previous_chunks,
     )
     result.ptransform.label = _get_label('from_zarr')
+    if pipeline is not None:
+      result._ptransform = _LazyPCollection(pipeline, result.ptransform)
     return result
 
   def _check_shards_or_chunks(
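A short usage sketch for the new `pipeline` keyword on the constructors above (the in-memory dataset, chunk sizes, and output path are invented for illustration):

    import apache_beam as beam
    import numpy as np
    import xarray
    import xarray_beam as xbeam

    source = xarray.Dataset({'foo': ('time', np.arange(100))})

    with beam.Pipeline() as p:
      # With pipeline=p, the transform is wrapped in _LazyPCollection, so the
      # result of to_zarr() is already applied and no `p | ...` is needed.
      xbeam.Dataset.from_xarray(source, {'time': 10}, pipeline=p).to_zarr(
          'example_output.zarr'
      )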
@@ -650,7 +699,7 @@ def to_zarr(
       zarr_shards: UnnormalizedChunks | None = None,
       zarr_format: int | None = None,
       stage_locally: bool | None = None,
-  ) -> beam.PTransform:
+  ) -> beam.PTransform | beam.PCollection:
     """Write this dataset to a Zarr file.
 
     The extensive options for controlling chunking and sharding are intended for
@@ -688,7 +737,7 @@ def to_zarr(
         path.
 
     Returns:
-      Beam PTransform that writes the dataset to a Zarr file.
+      Beam transform that writes the dataset to a Zarr file.
     """
     if zarr_shards is not None:
       zarr_shards = normalize_chunks(
@@ -889,15 +938,18 @@ def rechunk(
     )
     label = _get_label('rechunk')
 
-    if isinstance(self.ptransform, core.DatasetToChunks) and all(
+    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(ptransform, core.DatasetToChunks) and all(
         chunks[k] % self.chunks[k] == 0 for k in chunks
     ):
       # Rechunking can be performed by re-reading the source dataset with new
       # chunks, rather than using a separate rechunking transform.
       ptransform = core.DatasetToChunks(
-          self.ptransform.dataset, chunks, split_vars
+          ptransform.dataset, chunks, split_vars
       )
-      ptransform.label = _concat_labels(self.ptransform.label, label)
+      ptransform.label = _concat_labels(ptransform.label, label)
+      if pipeline is not None:
+        ptransform = _LazyPCollection(pipeline, ptransform)
       return type(self)(self.template, chunks, split_vars, ptransform)
 
     # Need to do a full rechunking.
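The fast path in this hunk hinges on the divisibility check `chunks[k] % self.chunks[k] == 0`: when every new chunk is a whole multiple of the old one, rechunking reduces to re-reading the source with the new chunks. A sketch of both outcomes, with invented chunk sizes:

    import xarray_beam as xbeam

    ds = xbeam.Dataset.from_zarr('example_data.zarr', chunks={'time': 10})

    cheap = ds.rechunk({'time': 30})  # 30 % 10 == 0: source is simply re-read
    full = ds.rechunk({'time': 15})   # 15 % 10 != 0: full rechunk transform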
@@ -982,23 +1034,25 @@ def mean(
 
   def head(self, **indexers_kwargs: int) -> Dataset:
     """Return a Dataset with the first N elements of each dimension."""
-    if not isinstance(self.ptransform, core.DatasetToChunks):
+    _, ptransform = _split_lazy_pcollection(self._ptransform)
+    if not isinstance(ptransform, core.DatasetToChunks):
       raise ValueError(
           'head() is only supported on untransformed datasets, with '
           'ptransform=DatasetToChunks. This dataset has '
-          f'ptransform={self.ptransform}'
+          f'ptransform={ptransform}'
       )
     return self._head(**indexers_kwargs)
 
   _tail = _whole_dataset_method('tail')
 
   def tail(self, **indexers_kwargs: int) -> Dataset:
     """Return a Dataset with the last N elements of each dimension."""
-    if not isinstance(self.ptransform, core.DatasetToChunks):
+    _, ptransform = _split_lazy_pcollection(self._ptransform)
+    if not isinstance(ptransform, core.DatasetToChunks):
       raise ValueError(
           'tail() is only supported on untransformed datasets, with '
           'ptransform=DatasetToChunks. This dataset has '
-          f'ptransform={self.ptransform}'
+          f'ptransform={ptransform}'
       )
     return self._tail(**indexers_kwargs)
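As the checks above show, `head()` and `tail()` only work while the dataset's transform is still a `DatasetToChunks` (now unwrapped from any `_LazyPCollection` first); anything already transformed raises `ValueError`. A brief sketch, assuming `example_data.zarr` exists:

    import xarray_beam as xbeam

    ds = xbeam.Dataset.from_zarr('example_data.zarr')
    preview = ds.head(time=10)  # fine: transform is still DatasetToChunks

    transformed = ds.map_blocks(lambda x: x + 1)
    # transformed.head(time=10)  # ValueError: head() is only supported on
    #                            # untransformed datasets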
