Skip to content

Commit 5966aab

Browse files
committed
Add tensor/batch factories and copies.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
1 parent 82274b4 commit 5966aab

File tree

4 files changed

+49
-9
lines changed

4 files changed

+49
-9
lines changed

dali/python/nvidia/dali/experimental/dali2/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
from ._eval_mode import EvalMode
2020
from ._type import * # noqa: F403
21-
from ._tensor import Tensor
22-
from ._batch import Batch
21+
from ._tensor import Tensor, tensor
22+
from ._batch import Batch, batch
2323
from ._device import Device
2424

2525
from . import fn

dali/python/nvidia/dali/experimental/dali2/_batch.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ._device import Device
2121
from . import _eval_mode
2222
from . import _invocation
23+
import copy
2324

2425

2526
class BatchedSlice:
@@ -74,10 +75,12 @@ def __init__(
7475
device: Optional[Device] = None,
7576
layout: Optional[str] = None,
7677
invocation_result: Optional[_invocation.InvocationResult] = None,
78+
copy: bool = False,
7779
):
7880
assert isinstance(layout, str) or layout is None
7981
self._wraps_external_data = False
8082
self._tensors = None
83+
copied = False
8184
if tensors is not None:
8285
self._tensors = []
8386
if len(tensors) == 0:
@@ -99,6 +102,8 @@ def __init__(
99102
self._tensors.append(sample)
100103
if sample._wraps_external_data:
101104
self._wraps_external_data = True
105+
else:
106+
copied = True
102107

103108
if dtype is not None:
104109
if not isinstance(dtype, DType):
@@ -112,6 +117,9 @@ def __init__(
112117
if self._tensors and self._tensors[0]._shape:
113118
self._ndim = len(self._tensors[0]._shape)
114119

120+
if copy and self._backend is not None and not copied:
121+
self.assign(self.to_device(self.device, force_copy=True).evaluate())
122+
115123
if _eval_mode.EvalMode.current().value >= _eval_mode.EvalMode.eager.value:
116124
self.evaluate()
117125

@@ -175,8 +183,8 @@ def tensors(self):
175183
t._backend = self._backend[i]
176184
return self._tensors
177185

178-
def to_device(self, device: Device) -> "Batch":
179-
if self.device == device:
186+
def to_device(self, device: Device, force_copy: bool = False) -> "Batch":
187+
if self.device == device and not force_copy:
180188
return self
181189
else:
182190
with device:
@@ -393,3 +401,15 @@ def __xor__(self, other):
393401

394402
def __rxor__(self, other):
395403
return _arithm_op("bitxor", other, self)
404+
405+
406+
def batch(
    tensors: Union[List[Any], Batch],
    dtype: Optional[DType] = None,
    device: Optional[Device] = None,
    layout: Optional[str] = None,
):
    """Create a ``Batch`` that owns a copy of the supplied data.

    Parameters
    ----------
    tensors : list or Batch
        Either an existing ``Batch`` (which is deep-copied) or a list of
        per-sample data to wrap in a new ``Batch``.
    dtype : DType, optional
        Requested element type (applied when constructing from a list).
    device : Device, optional
        Target device; when copying an existing ``Batch`` and no device is
        given, the batch's current device is kept.
    layout : str, optional
        Sample layout string (applied when constructing from a list).

    Returns
    -------
    Batch
        A batch that owns its data (never wraps external memory).
    """
    if isinstance(tensors, Batch):
        # Keep the source batch's device when none was requested.
        target_device = device if device is not None else tensors.device
        # BUG FIX: the original assigned this to a local and fell through,
        # so this branch returned None. Return the evaluated copy instead.
        # NOTE(review): dtype/layout are ignored on this path — presumably
        # the copy keeps the source's dtype/layout; confirm intended.
        return tensors.to_device(target_device, force_copy=True).evaluate()
    else:
        # Build a fresh Batch; copy=True guarantees owned (non-external) data.
        return Batch(tensors, dtype=dtype, device=device, layout=layout, copy=True)

dali/python/nvidia/dali/experimental/dali2/_op_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ def call(*inputs, batch_size=None, device=None, **raw_kwargs):
431431

432432
# If device is not specified, infer it from the inputs and call_args
433433
if device is None:
434+
434435
def _infer_device():
435436
for inp in inputs:
436437
if inp is None:
@@ -445,6 +446,7 @@ def _infer_device():
445446
if dev is not None and dev.device_type == "gpu":
446447
return dev
447448
return _device.Device("cpu")
449+
448450
device = _infer_device()
449451
elif not isinstance(device, _device.Device):
450452
device = _device.Device(device)

dali/python/nvidia/dali/experimental/dali2/_tensor.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
batch: Optional[Any] = None,
4141
index_in_batch: Optional[int] = None,
4242
invocation_result: Optional[_invocation.InvocationResult] = None,
43+
copy: bool = False,
4344
):
4445
if layout is None:
4546
layout = ""
@@ -52,6 +53,7 @@ def __init__(
5253
self._index_in_batch = index_in_batch
5354
self._invocation_result = None
5455
self._wraps_external_data = False
56+
copied = False
5557

5658
from . import fn
5759

@@ -77,13 +79,15 @@ def __init__(
7779
self.assign(data)
7880
self._wraps_external_data = data._wraps_external_data
7981
else:
80-
self.assign(data.to_device(device).evaluate())
82+
dev = data.to_device(device).evaluate()
83+
if dev is not self:
84+
copied = True
85+
self.assign(dev)
86+
self._wraps_external_data = not copied
8187
else:
8288
self.assign(fn.cast(data, dtype, device=device).evaluate())
83-
return
8489
elif isinstance(data, TensorSlice):
8590
self._slice = data
86-
return
8791
elif hasattr(data, "__dlpack__"):
8892
self._backend = TensorCPU(data, layout)
8993
self._wraps_external_data = True
@@ -99,10 +103,12 @@ def __init__(
99103
layout,
100104
False,
101105
)
106+
copied = True
102107
self._wraps_external_data = False
103108
self._dtype = dtype
104109
else:
105110
self._backend = TensorCPU(np.array(data), layout, False)
111+
copied = True
106112
self._wraps_external_data = False
107113

108114
if device is not None:
@@ -131,6 +137,9 @@ def __init__(
131137
if _eval_mode.EvalMode.current().value >= _eval_mode.EvalMode.eager.value:
132138
self.evaluate()
133139

140+
if copy and self._backend is not None and not copied:
141+
self.assign(self.to_device(self.device, force_copy=True).evaluate())
142+
134143
def _is_external(self) -> bool:
135144
return self._wraps_external_data
136145

@@ -150,8 +159,8 @@ def device(self) -> Device:
150159
else:
151160
raise RuntimeError("Device not set")
152161

153-
def to_device(self, device: Device) -> "Tensor":
154-
if self.device == device:
162+
def to_device(self, device: Device, force_copy: bool = False) -> "Tensor":
163+
if self.device == device and not force_copy:
155164
return self
156165
else:
157166
with device:
@@ -548,3 +557,12 @@ def evaluate(self):
548557
from . import fn
549558

550559
return fn.tensor_subscript(self._tensor, **args).evaluate()
560+
561+
562+
def tensor(
    data: Any,
    dtype: Optional[Any] = None,
    device: Optional[Device] = None,
    layout: Optional[str] = None,
):
    """Construct a ``Tensor`` that owns a copy of ``data``.

    Thin factory over ``Tensor``: forwards all arguments unchanged and
    passes ``copy=True`` so the result never wraps external memory.

    Parameters
    ----------
    data : Any
        Source data (e.g. another Tensor, an array-like, or a scalar).
    dtype : optional
        Requested element type.
    device : Device, optional
        Target device for the tensor.
    layout : str, optional
        Layout string describing the data's dimensions.

    Returns
    -------
    Tensor
        A tensor holding its own copy of the data.
    """
    return Tensor(data, dtype=dtype, device=device, layout=layout, copy=True)

0 commit comments

Comments
 (0)