@@ -17,14 +17,13 @@
 
   import xarray_beam as xbeam
 
-  transform = (
-      xbeam.Dataset.from_zarr(input_path)
-      .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
-      .map_blocks(lambda x: x.median('time'))
-      .to_zarr(output_path)
-  )
   with beam.Pipeline() as p:
-    p | transform
+    (
+        xbeam.Dataset.from_zarr(input_path, pipeline=p)
+        .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
+        .map_blocks(lambda x: x.median('time'))
+        .to_zarr(output_path)
+    )
 """
 from __future__ import annotations
 
@@ -371,21 +370,21 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
     chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
 
     label = _get_label(method_name)
-    if isinstance(self.ptransform, core.DatasetToChunks):
+
+    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(ptransform, core.DatasetToChunks):
       # Some transformations (e.g., indexing) can be applied much less
       # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
       # to preserve this option for downstream transformations if possible.
-      dataset = func(self.ptransform.dataset)
+      dataset = func(ptransform.dataset)
       ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
-      ptransform.label = _concat_labels(self.ptransform.label, label)
+      ptransform.label = _concat_labels(ptransform.label, label)
+      if pipeline is not None:
+        ptransform = _LazyPCollection(pipeline, ptransform)
     else:
       ptransform = self.ptransform | label >> beam.MapTuple(
           functools.partial(
-              _apply_to_each_chunk,
-              func,
-              method_name,
-              self.chunks,
-              chunks
+              _apply_to_each_chunk, func, method_name, self.chunks, chunks
           )
       )
     return Dataset(template, chunks, self.split_vars, ptransform)
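
The else-branch above relies on Beam unpacking each (key, chunk) tuple and on
functools.partial pre-binding the leading arguments. A minimal runnable sketch
of that pattern, using a hypothetical _process stand-in for _apply_to_each_chunk:

  import functools
  import apache_beam as beam

  def _process(fn, name, key, value):
    # partial() pre-binds fn and name; beam.MapTuple supplies key and value
    # by unpacking each 2-tuple element of the PCollection.
    return key, fn(value)

  with beam.Pipeline() as p:
    (
        p
        | beam.Create([('a', 1), ('b', 2)])
        | 'double' >> beam.MapTuple(
            functools.partial(_process, lambda x: 2 * x, 'double')
        )
    )
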
@@ -405,6 +404,43 @@ def apply(self, name: str) -> str:
 _get_label = _CountNamer().apply
 
 
+@dataclasses.dataclass(frozen=True)
+class _LazyPCollection:
+  """Pipeline and PTransform not yet combined into a PCollection."""
+  # Beam does not provide a public API for manipulating Pipeline objects, so
+  # instead of applying pipelines eagerly, we store them in this wrapper. This
+  # allows for performance optimizations specialized to Xarray-Beam PTransforms
+  # (in particular for DatasetToChunks), even in the case where a pipeline is
+  # supplied.
+  pipeline: beam.Pipeline
+  ptransform: beam.PTransform
+
+  # Cache the evaluated PCollection, so we apply the transform to the pipeline
+  # at most once. Otherwise, reapplying the same transform results in reused
+  # labels in the same pipeline, which is an error in Beam.
+  @functools.cached_property
+  def evaluated(self) -> beam.PCollection:
+    return self.pipeline | self.ptransform
+
+
+def _split_lazy_pcollection(
+    value: beam.PTransform | beam.PCollection | _LazyPCollection,
+) -> tuple[beam.Pipeline | None, beam.PTransform | beam.PCollection]:
+  if isinstance(value, _LazyPCollection):
+    return value.pipeline, value.ptransform
+  else:
+    return None, value
+
+
+def _as_eager_pcollection_or_ptransform(
+    value: beam.PTransform | beam.PCollection | _LazyPCollection,
+) -> beam.PTransform | beam.PCollection:
+  if isinstance(value, _LazyPCollection):
+    return value.evaluated
+  else:
+    return value
+
+
 @core.export
 @dataclasses.dataclass
 class Dataset:
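
Why cache `evaluated`? Applying the same labeled transform to one pipeline
twice is an error in Beam. A minimal sketch of that failure mode (names here
are illustrative, not from the diff):

  import apache_beam as beam

  p = beam.Pipeline()
  create = 'make' >> beam.Create([1, 2, 3])
  pcoll = p | create  # first application is fine
  try:
    p | create  # second application reuses the label 'make'
  except RuntimeError as e:
    print(e)  # Beam rejects duplicate transform labels in a pipeline

Note that functools.cached_property works on the frozen dataclass because it
writes to the instance __dict__ directly, bypassing the frozen __setattr__.
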
@@ -415,7 +451,7 @@ def __init__(
       template: xarray.Dataset,
       chunks: Mapping[str, int],
       split_vars: bool,
-      ptransform: beam.PTransform,
+      ptransform: beam.PTransform | beam.PCollection | _LazyPCollection,
   ):
     """Low level interface for creating a new Dataset, without validation.
 
@@ -429,7 +465,7 @@
         use :py:func:`xarray_beam.normalize_chunks`.
       split_vars: whether variables are split between separate elements in the
         ptransform, or all stored in the same element.
-      ptransform: Beam PTransform of ``(xbeam.Key, xarray.Dataset)`` tuples with
+      ptransform: Beam collection of ``(xbeam.Key, xarray.Dataset)`` tuples with
         this dataset's data.
     """
     self._template = template
@@ -453,9 +489,9 @@ def split_vars(self) -> bool:
     return self._split_vars
 
   @property
-  def ptransform(self) -> beam.PTransform:
+  def ptransform(self) -> beam.PTransform | beam.PCollection:
     """Beam PTransform of (xbeam.Key, xarray.Dataset) with this dataset's data."""
-    return self._ptransform
+    return _as_eager_pcollection_or_ptransform(self._ptransform)
 
   @property
   def sizes(self) -> Mapping[str, int]:
@@ -507,7 +543,7 @@ def __repr__(self):
     plural = 's' if chunk_count != 1 else ''
     return (
         '<xarray_beam.Dataset>\n'
-        f'PTransform: {self.ptransform}\n'
+        f'PTransform: {self._ptransform}\n'
         f'Chunks: {chunk_size} ({chunks_str})\n'
         f'Template: {total_size} ({chunk_count} chunk{plural})\n'
         + textwrap.indent('\n'.join(base.split('\n')[1:]), ' ' * 4)
@@ -516,7 +552,7 @@
   @classmethod
   def from_ptransform(
       cls,
-      ptransform: beam.PTransform,
+      ptransform: beam.PTransform | beam.PCollection,
       *,
       template: xarray.Dataset,
       chunks: Mapping[str | types.EllipsisType, int],
@@ -535,9 +571,9 @@ def from_ptransform(
     outputs are valid.
 
     Args:
-      ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)`` pairs.
-        You only need to set ``offsets`` on these keys, ``vars`` will be
-        automatically set based on the dataset if ``split_vars`` is True.
+      ptransform: A Beam collection of ``(Key, xarray.Dataset)`` pairs. You only
+        need to set ``offsets`` on these keys; ``vars`` will be automatically
+        set based on the dataset if ``split_vars`` is True.
       template: An ``xarray.Dataset`` object representing the schema
         (coordinates, dimensions, data variables, and attributes) of the full
         dataset, as produced by :py:func:`xarray_beam.make_template`, with data
@@ -577,6 +613,7 @@ def from_xarray(
       *,
       split_vars: bool = False,
      previous_chunks: Mapping[str, int] | None = None,
+      pipeline: beam.Pipeline | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from an xarray.Dataset.
 
@@ -588,13 +625,17 @@
         ptransform, or all stored in the same element.
       previous_chunks: chunks hint used for parsing string values in ``chunks``
         with ``normalize_chunks()``.
+      pipeline: Beam pipeline to use for this dataset. If not provided, you
+        will need to apply a pipeline later to compute this dataset.
     """
     template = zarr.make_template(source)
     if previous_chunks is None:
       previous_chunks = source.sizes
     chunks = normalize_chunks(chunks, template, split_vars, previous_chunks)
     ptransform = core.DatasetToChunks(source, chunks, split_vars)
     ptransform.label = _get_label('from_xarray')
+    if pipeline is not None:
+      ptransform = _LazyPCollection(pipeline, ptransform)
     return cls(template, dict(chunks), split_vars, ptransform)
 
   @classmethod
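
A usage sketch for the new `pipeline` argument of from_xarray (the in-memory
dataset and the 'output.zarr' path are illustrative):

  import apache_beam as beam
  import numpy as np
  import xarray
  import xarray_beam as xbeam

  source = xarray.Dataset({'foo': ('x', np.arange(10))})
  with beam.Pipeline() as p:
    (
        xbeam.Dataset.from_xarray(source, {'x': 5}, pipeline=p)
        .to_zarr('output.zarr')
    )
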
@@ -604,6 +645,7 @@ def from_zarr(
       *,
       chunks: UnnormalizedChunks | None = None,
       split_vars: bool = False,
+      pipeline: beam.Pipeline | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Zarr store.
 
@@ -614,6 +656,8 @@
         provided, the chunk sizes will be inferred from the Zarr file.
       split_vars: whether variables are split between separate elements in the
         ptransform, or all stored in the same element.
+      pipeline: Beam pipeline to use for this dataset. If not provided, you
+        will need to apply a pipeline later to compute this dataset.
 
     Returns:
       New Dataset created from the Zarr store.
@@ -622,9 +666,14 @@
     if chunks is None:
       chunks = previous_chunks
     result = cls.from_xarray(
-        source, chunks, split_vars=split_vars, previous_chunks=previous_chunks
+        source,
+        chunks,
+        split_vars=split_vars,
+        previous_chunks=previous_chunks,
     )
     result.ptransform.label = _get_label('from_zarr')
+    if pipeline is not None:
+      result._ptransform = _LazyPCollection(pipeline, result.ptransform)
     return result
 
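
The pre-existing deferred style still works when `pipeline` is omitted: the
dataset's transform must then be applied to a pipeline explicitly, as in this
sketch (input_path and output_path as in the module docstring):

  import apache_beam as beam
  import xarray_beam as xbeam

  ds = xbeam.Dataset.from_zarr(input_path)  # no pipeline given
  with beam.Pipeline() as p:
    p | ds.to_zarr(output_path)
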
630679 def _check_shards_or_chunks (
@@ -650,7 +699,7 @@ def to_zarr(
650699 zarr_shards : UnnormalizedChunks | None = None ,
651700 zarr_format : int | None = None ,
652701 stage_locally : bool | None = None ,
653- ) -> beam .PTransform :
702+ ) -> beam .PTransform | beam . PCollection :
654703 """Write this dataset to a Zarr file.
655704
656705 The extensive options for controlling chunking and sharding are intended for
@@ -688,7 +737,7 @@
       path.
 
     Returns:
-      Beam PTransform that writes the dataset to a Zarr file.
+      Beam transform that writes the dataset to a Zarr file.
     """
     if zarr_shards is not None:
       zarr_shards = normalize_chunks(
@@ -889,15 +938,18 @@ def rechunk(
     )
     label = _get_label('rechunk')
 
-    if isinstance(self.ptransform, core.DatasetToChunks) and all(
+    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(ptransform, core.DatasetToChunks) and all(
         chunks[k] % self.chunks[k] == 0 for k in chunks
     ):
       # Rechunking can be performed by re-reading the source dataset with new
       # chunks, rather than using a separate rechunking transform.
       ptransform = core.DatasetToChunks(
-          self.ptransform.dataset, chunks, split_vars
+          ptransform.dataset, chunks, split_vars
       )
-      ptransform.label = _concat_labels(self.ptransform.label, label)
+      ptransform.label = _concat_labels(ptransform.label, label)
+      if pipeline is not None:
+        ptransform = _LazyPCollection(pipeline, ptransform)
       return type(self)(self.template, chunks, split_vars, ptransform)
 
     # Need to do a full rechunking.
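
The fast path above applies only when each new chunk size is an exact multiple
of the current one, so chunks can be regrouped as the source is re-read. For
example:

  current = {'time': 10}
  new = {'time': 20}
  assert all(new[k] % current[k] == 0 for k in new)  # fast path: re-read source

  # new = {'time': 15} would fail this check and fall through to the full
  # rechunk below, i.e. a distributed shuffle of chunk data.
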
@@ -982,23 +1034,25 @@ def mean(
 
   def head(self, **indexers_kwargs: int) -> Dataset:
     """Return a Dataset with the first N elements of each dimension."""
-    if not isinstance(self.ptransform, core.DatasetToChunks):
+    _, ptransform = _split_lazy_pcollection(self._ptransform)
+    if not isinstance(ptransform, core.DatasetToChunks):
       raise ValueError(
           'head() is only supported on untransformed datasets, with '
           'ptransform=DatasetToChunks. This dataset has '
-          f'ptransform={self.ptransform}'
+          f'ptransform={ptransform}'
       )
     return self._head(**indexers_kwargs)
 
   _tail = _whole_dataset_method('tail')
 
   def tail(self, **indexers_kwargs: int) -> Dataset:
     """Return a Dataset with the last N elements of each dimension."""
-    if not isinstance(self.ptransform, core.DatasetToChunks):
+    _, ptransform = _split_lazy_pcollection(self._ptransform)
+    if not isinstance(ptransform, core.DatasetToChunks):
       raise ValueError(
           'tail() is only supported on untransformed datasets, with '
           'ptransform=DatasetToChunks. This dataset has '
-          f'ptransform={self.ptransform}'
+          f'ptransform={ptransform}'
       )
     return self._tail(**indexers_kwargs)
 
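
A sketch of the head()/tail() restriction above: both methods work only while
the dataset is still backed by DatasetToChunks, i.e. read but not otherwise
transformed (input_path and the 'time' dimension as in the module docstring):

  import xarray_beam as xbeam

  ds = xbeam.Dataset.from_zarr(input_path)
  first = ds.head(time=10)  # OK: ptransform is DatasetToChunks
  transformed = ds.map_blocks(lambda x: x + 1)
  # transformed.head(time=10) would raise the ValueError above.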