diff --git a/dali/operators/generic/resize/tensor_resize_cpu.cc b/dali/operators/generic/resize/tensor_resize_cpu.cc index d231fb86506..59624b2a4bf 100644 --- a/dali/operators/generic/resize/tensor_resize_cpu.cc +++ b/dali/operators/generic/resize/tensor_resize_cpu.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class TensorResizeCPU : public TensorResize { } // namespace tensor_resize -DALI_SCHEMA(experimental__TensorResize) +DALI_SCHEMA(TensorResize) .DocStr(R"code(Resize tensors.)code") .NumInput(1) .NumOutput(1) @@ -45,8 +45,26 @@ DALI_SCHEMA(experimental__TensorResize) .AddParent("ResamplingFilterAttr") .AddParent("TensorResizeAttr"); +// Deprecated alias +DALI_SCHEMA(experimental__TensorResize) + .AddParent("TensorResize") + .DocStr("Legacy alias for :meth:`tensor_resize`.") + .NumInput(1) + .NumOutput(1) + .MakeDocHidden() + .SupportVolumetric() + .AllowSequences() + .Deprecate( + "2.0", + "TensorResize", + "This operator was moved out from the experimental phase, " + "and is now a regular DALI operator. This is just a deprecated " + "alias kept for backward compatibility."); +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__TensorResize, tensor_resize::TensorResizeCPU, CPU); +DALI_REGISTER_OPERATOR(TensorResize, tensor_resize::TensorResizeCPU, CPU); + } // namespace dali diff --git a/dali/operators/generic/resize/tensor_resize_gpu.cc b/dali/operators/generic/resize/tensor_resize_gpu.cc index d6331dce974..4cdf0ff39c9 100644 --- a/dali/operators/generic/resize/tensor_resize_gpu.cc +++ b/dali/operators/generic/resize/tensor_resize_gpu.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,6 +37,9 @@ class TensorResizeGPU : public TensorResize { } // namespace tensor_resize +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__TensorResize, tensor_resize::TensorResizeGPU, GPU); +DALI_REGISTER_OPERATOR(TensorResize, tensor_resize::TensorResizeGPU, GPU); + } // namespace dali diff --git a/dali/operators/image/color/debayer.cc b/dali/operators/image/color/debayer.cc index 59308698ee0..49ece097b23 100644 --- a/dali/operators/image/color/debayer.cc +++ b/dali/operators/image/color/debayer.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ namespace dali { -DALI_SCHEMA(experimental__Debayer) +DALI_SCHEMA(Debayer) .DocStr(R"code(Performs image demosaicing/debayering. Converts single-channel image to RGB using specified color filter array. @@ -114,4 +114,20 @@ Different algorithms are supported on the GPU and CPU. .InputLayout(0, {"HW", "HWC", "FHW", "FHWC"}) .AllowSequences(); +// Deprecated alias +DALI_SCHEMA(experimental__Debayer) + .AddParent("Debayer") + .DocStr("Legacy alias for :meth:`debayer`.") + .NumInput(1) + .NumOutput(1) + .MakeDocHidden() + .InputLayout(0, {"HW", "HWC", "FHW", "FHWC"}) + .AllowSequences() + .Deprecate( + "2.0", + "Debayer", + "This operator was moved out from the experimental phase, " + "and is now a regular DALI operator. This is just a deprecated " + "alias kept for backward compatibility."); + } // namespace dali diff --git a/dali/operators/image/color/debayer_cpu.cc b/dali/operators/image/color/debayer_cpu.cc index aa7dbcc1815..55510e6f6bb 100644 --- a/dali/operators/image/color/debayer_cpu.cc +++ b/dali/operators/image/color/debayer_cpu.cc @@ -134,6 +134,9 @@ class DebayerCPU : public Debayer { } }; +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Debayer, DebayerCPU, CPU); +DALI_REGISTER_OPERATOR(Debayer, DebayerCPU, CPU); + } // namespace dali diff --git a/dali/operators/image/color/debayer_gpu.cc b/dali/operators/image/color/debayer_gpu.cc index 0e9436daa21..87fb0c23842 100644 --- a/dali/operators/image/color/debayer_gpu.cc +++ b/dali/operators/image/color/debayer_gpu.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -107,6 +107,9 @@ void DebayerGPU::RunImpl(Workspace &ws) { impl_->RunImpl(ws); } +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Debayer, DebayerGPU, GPU); +DALI_REGISTER_OPERATOR(Debayer, DebayerGPU, GPU); + } // namespace dali diff --git a/dali/operators/image/color/equalize.cc b/dali/operators/image/color/equalize.cc index 6ac935a12e6..6094cc974bf 100644 --- a/dali/operators/image/color/equalize.cc +++ b/dali/operators/image/color/equalize.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ namespace dali { -DALI_SCHEMA(experimental__Equalize) +DALI_SCHEMA(Equalize) .DocStr(R"code(Performs grayscale/per-channel histogram equalization. The supported inputs are images and videos of uint8_t type.)code") @@ -31,6 +31,22 @@ The supported inputs are images and videos of uint8_t type.)code") .InputLayout(0, {"HW", "HWC", "CHW", "FHW", "FHWC", "FCHW"}) .AllowSequences(); +// Deprecated alias +DALI_SCHEMA(experimental__Equalize) + .AddParent("Equalize") + .DocStr("Legacy alias for :meth:`equalize`.") + .NumInput(1) + .NumOutput(1) + .MakeDocHidden() + .InputLayout(0, {"HW", "HWC", "CHW", "FHW", "FHWC", "FCHW"}) + .AllowSequences() + .Deprecate( + "2.0", + "Equalize", + "This operator was moved out from the experimental phase, " + "and is now a regular DALI operator. This is just a deprecated " + "alias kept for backward compatibility."); + namespace equalize { class EqualizeCPU : public Equalize { @@ -100,6 +116,9 @@ class EqualizeCPU : public Equalize { } // namespace equalize +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Equalize, equalize::EqualizeCPU, CPU); +DALI_REGISTER_OPERATOR(Equalize, equalize::EqualizeCPU, CPU); + } // namespace dali diff --git a/dali/operators/image/color/equalize.cu b/dali/operators/image/color/equalize.cu index d7a7b40b6e2..796fd8d981d 100644 --- a/dali/operators/image/color/equalize.cu +++ b/dali/operators/image/color/equalize.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -73,6 +73,9 @@ class EqualizeGPU : public Equalize { } // namespace equalize +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Equalize, equalize::EqualizeGPU, GPU); +DALI_REGISTER_OPERATOR(Equalize, equalize::EqualizeGPU, GPU); + } // namespace dali diff --git a/dali/operators/image/convolution/filter.cc b/dali/operators/image/convolution/filter.cc index cf9bf013304..bac65c1a9c7 100644 --- a/dali/operators/image/convolution/filter.cc +++ b/dali/operators/image/convolution/filter.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ namespace dali { -DALI_SCHEMA(experimental__Filter) +DALI_SCHEMA(Filter) .DocStr(R"code(Convolves the image with the provided filter. .. note:: @@ -118,6 +118,22 @@ If not set, the input type is used. the values will be clamped to the output type range. )code"); +// Deprecated alias +DALI_SCHEMA(experimental__Filter) + .AddParent("Filter") + .DocStr("Legacy alias for :meth:`filter`.") + .NumInput(2, 3) + .NumOutput(1) + .InputDevice(1, 3, InputDevice::MatchBackendOrCPU) + .AllowSequences() + .MakeDocHidden() + .Deprecate( + "2.0", + "Filter", + "This operator was moved out from the experimental phase, " + "and is now a regular DALI operator. This is just a deprecated " + "alias kept for backward compatibility."); + namespace filter { namespace ocv { @@ -352,6 +368,9 @@ std::unique_ptr> Filter::GetFilterImpl( return filter::get_filter_cpu_op_impl(spec, input_desc); } +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Filter, Filter, CPU); +DALI_REGISTER_OPERATOR(Filter, Filter, CPU); + } // namespace dali diff --git a/dali/operators/image/convolution/filter.cu b/dali/operators/image/convolution/filter.cu index 4dc207566c5..1e4e1975ebf 100644 --- a/dali/operators/image/convolution/filter.cu +++ b/dali/operators/image/convolution/filter.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -46,6 +46,9 @@ std::unique_ptr> Filter::GetFilterImpl( return filter::get_filter_gpu_op_impl(spec, input_desc); } +// Kept for backwards compatibility DALI_REGISTER_OPERATOR(experimental__Filter, Filter, GPU); +DALI_REGISTER_OPERATOR(Filter, Filter, GPU); + } // namespace dali diff --git a/dali/operators/video/decoder/video_decoder_cpu.cc b/dali/operators/video/decoder/video_decoder_cpu.cc index 0babfaba51f..3b245377c41 100644 --- a/dali/operators/video/decoder/video_decoder_cpu.cc +++ b/dali/operators/video/decoder/video_decoder_cpu.cc @@ -17,7 +17,7 @@ namespace dali { -DALI_SCHEMA(experimental__decoders__Video) +DALI_SCHEMA(decoders__Video) .DocStr( R"code(Decodes videos from in-memory streams. @@ -195,6 +195,19 @@ Building an index is particularly useful when decoding a small number of frames apart or starting playback from a frame deep into the video.)code", true); +DALI_SCHEMA(experimental__decoders__Video) + .AddParent("decoders__Video") + .DocStr("Legacy alias for :meth:`decoders.video`.") + .NumInput(1) + .NumOutput(1) + .MakeDocHidden() + .Deprecate( + "2.0", + "decoders__Video", + "This operator was moved out from the experimental phase, " + "and is now a regular DALI operator. This is just a deprecated " + "alias kept for backward compatibility."); + class VideoDecoderCpu : public VideoDecoderBase { public: explicit VideoDecoderCpu(const OpSpec &spec) : @@ -202,5 +215,6 @@ class VideoDecoderCpu : public VideoDecoderBase { }; DALI_REGISTER_OPERATOR(experimental__decoders__Video, VideoDecoderCpu, CPU); +DALI_REGISTER_OPERATOR(decoders__Video, VideoDecoderCpu, CPU); } // namespace dali diff --git a/dali/operators/video/decoder/video_decoder_mixed.cc b/dali/operators/video/decoder/video_decoder_mixed.cc index 72eb2b954de..69a3f105b7b 100644 --- a/dali/operators/video/decoder/video_decoder_mixed.cc +++ b/dali/operators/video/decoder/video_decoder_mixed.cc @@ -23,5 +23,6 @@ class VideoDecoderMixed : public VideoDecoderBase= 1: + out0 = [ + out0_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[0], out_shapes_np[0]) + ] + if num_outs >= 2: + out1 = [ + out1_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[1], out_shapes_np[1]) + ] + if num_outs >= 3: + out2 = [ + out2_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[2], out_shapes_np[2]) + ] + if num_outs >= 4: + out3 = [ + out3_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[3], out_shapes_np[3]) + ] + if num_outs >= 5: + out4 = [ + out4_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[4], out_shapes_np[4]) + ] + if num_outs >= 6: + out5 = [ + out5_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(out_arr[5], out_shapes_np[5]) + ] + + in0 = in1 = in2 = in3 = in4 = in5 = None + in_shapes_np = _get_shape_view(in_shapes_ptr, in_ndims_ptr, num_ins, num_samples) + in_arr = carray( + address_as_void_pointer(in_ptr), (num_ins, num_samples), dtype=np.int64 + ) + if num_ins >= 1: + in0 = [ + in0_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[0], in_shapes_np[0]) + ] + if num_ins >= 2: + in1 = [ + in1_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[1], in_shapes_np[1]) + ] + if num_ins >= 3: + in2 = [ + in2_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[2], in_shapes_np[2]) + ] + if num_ins >= 4: + in3 = [ + in3_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[3], in_shapes_np[3]) + ] + if num_ins >= 5: + in4 = [ + in4_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[4], in_shapes_np[4]) + ] + if num_ins >= 6: + in5 = [ + in5_lambda(address_as_void_pointer(ptr), shape) + for ptr, shape in zip(in_arr[5], in_shapes_np[5]) + ] + + run_fn_lambda( + run_fn, out0, out1, out2, out3, out4, out5, in0, in1, in2, in3, in4, in5 + ) + + else: + + @cfunc(self._run_fn_sig(batch_processing=False), nopython=True) + def run_cfunc( + out_ptr, + out_shapes_ptr, + out_ndims_ptr, + num_outs, + in_ptr, + in_shapes_ptr, + in_ndims_ptr, + num_ins, + ): + out0 = out1 = out2 = out3 = out4 = out5 = None + out_shapes_np = _get_shape_view(out_shapes_ptr, out_ndims_ptr, num_outs, 1) + out_arr = carray(address_as_void_pointer(out_ptr), num_outs, dtype=np.int64) + if num_outs >= 1: + out0 = out0_lambda(address_as_void_pointer(out_arr[0]), out_shapes_np[0][0]) + if num_outs >= 2: + out1 = out1_lambda(address_as_void_pointer(out_arr[1]), out_shapes_np[1][0]) + if num_outs >= 3: + out2 = out2_lambda(address_as_void_pointer(out_arr[2]), out_shapes_np[2][0]) + if num_outs >= 4: + out3 = out3_lambda(address_as_void_pointer(out_arr[3]), out_shapes_np[3][0]) + if num_outs >= 5: + out4 = out4_lambda(address_as_void_pointer(out_arr[4]), out_shapes_np[4][0]) + if num_outs >= 6: + out5 = out5_lambda(address_as_void_pointer(out_arr[5]), out_shapes_np[5][0]) + + in0 = in1 = in2 = in3 = in4 = in5 = None + in_shapes_np = _get_shape_view(in_shapes_ptr, in_ndims_ptr, num_ins, 1) + in_arr = carray(address_as_void_pointer(in_ptr), num_ins, dtype=np.int64) + if num_ins >= 1: + in0 = in0_lambda(address_as_void_pointer(in_arr[0]), in_shapes_np[0][0]) + if num_ins >= 2: + in1 = in1_lambda(address_as_void_pointer(in_arr[1]), in_shapes_np[1][0]) + if num_ins >= 3: + in2 = in2_lambda(address_as_void_pointer(in_arr[2]), in_shapes_np[2][0]) + if num_ins >= 4: + in3 = in3_lambda(address_as_void_pointer(in_arr[3]), in_shapes_np[3][0]) + if num_ins >= 5: + in4 = in4_lambda(address_as_void_pointer(in_arr[4]), in_shapes_np[4][0]) + if num_ins >= 6: + in5 = in5_lambda(address_as_void_pointer(in_arr[5]), in_shapes_np[5][0]) + + run_fn_lambda( + run_fn, out0, out1, out2, out3, out4, out5, in0, in1, in2, in3, in4, in5 + ) + + return run_cfunc.address + + def __call__(self, *inputs, **kwargs): + pipeline = Pipeline.current() + inputs = ops._preprocess_inputs(inputs, self.__class__.__name__, self._device, None) + if pipeline is None: + Pipeline._raise_pipeline_required(self.__class__.__name__) + for inp in inputs: + if not isinstance(inp, _DataNode): + raise TypeError( + ( + "Expected inputs of type `DataNode`. Received input of type '{}'. " + + "Python Operators do not support Multiple Input Sets." + ).format(type(inp).__name__) + ) + + kwargs.update( + { + "run_fn": self.run_fn, + "out_types": self.out_types, + "in_types": self.in_types, + "outs_ndim": self.outs_ndim, + "ins_ndim": self.ins_ndim, + "batch_processing": self.batch_processing, + } + ) + if self.setup_fn is not None: + kwargs.update({"setup_fn": self.setup_fn}) + if self.device == "gpu": + kwargs.update( + { + "blocks": self.blocks, + "threads_per_block": self.threads_per_block, + } + ) + + return super().__call__(*inputs, **kwargs) + + def __init__( + self, + run_fn, + out_types, + in_types, + outs_ndim, + ins_ndim, + setup_fn=None, + device="cpu", + batch_processing=False, + blocks=None, + threads_per_block=None, + **kwargs, + ): + if device == "gpu": + NumbaFunction._check_minimal_numba_version() + NumbaFunction._check_cuda_compatibility() + + # TODO(klecki): Normalize the types into lists first, than apply the checks + assert len(in_types) == len(ins_ndim), ( + "Number of input types " "and input dimensions should match." + ) + assert len(out_types) == len(outs_ndim), ( + "Number of output types " "and output dimensions should match." + ) + + if "float16" in dir(numba_types): + for t in [*in_types, *out_types]: + if t == dali_types.FLOAT16: + raise RuntimeError( + "Numba does not support float16 for " + "current Python version. " + "Python 3.7 or newer is required" + ) + + if device == "gpu": + assert batch_processing is False, "Currently batch processing for GPU is not supported." + assert len(blocks) == 3, ( + "`blocks` array should contain 3 numbers, while received: " f"{len(blocks)}" + ) + for i, block_dim in enumerate(blocks): + assert block_dim > 0, ( + "All dimensions should be positive. Value specified in " + f"`blocks` at index {i} is nonpositive: {block_dim}" + ) + + assert len(threads_per_block) == 3, ( + "`threads_per_block` array should contain 3 " + f"numbers, while received: {len(threads_per_block)}" + ) + for i, threads in enumerate(threads_per_block): + assert threads > 0, ( + "All dimensions should be positive. " + "Value specified in `threads_per_block` at index " + f"{i} is nonpositive: {threads}" + ) + + if not isinstance(outs_ndim, list): + outs_ndim = [outs_ndim] + if not isinstance(ins_ndim, list): + ins_ndim = [ins_ndim] + if not isinstance(out_types, list): + out_types = [out_types] + if not isinstance(in_types, list): + in_types = [in_types] + + super().__init__(device=device, **kwargs) + + if device == "gpu": + self.run_fn = self._get_run_fn_gpu(run_fn, out_types + in_types, outs_ndim + ins_ndim) + else: + self.run_fn = self._get_run_fn_cpu( + run_fn, out_types, in_types, outs_ndim, ins_ndim, batch_processing + ) + self.setup_fn = self._get_setup_fn_cpu(setup_fn) + self.out_types = out_types + self.in_types = in_types + self.outs_ndim = outs_ndim + self.ins_ndim = ins_ndim + self.num_outputs = len(out_types) + self.batch_processing = batch_processing + self._preserve = True + self.blocks = blocks + self.threads_per_block = threads_per_block + + @staticmethod + def _check_minimal_numba_version(throw: bool = True): + current_version = Version(nb.__version__) + toolkit_version = nb_cuda.runtime.get_version() + if toolkit_version[0] not in minimal_numba_version: + if throw: + raise RuntimeError(f"Unsupported CUDA toolkit version: {toolkit_version}") + else: + return False + min_ver = minimal_numba_version[toolkit_version[0]] + if current_version < min_ver: + if throw: + raise RuntimeError( + f"Insufficient Numba version. Numba GPU operator " + f"requires Numba {str(min_ver)} or higher. " + f"Detected version: {str(Version(nb.__version__))}." + ) + else: + return False + return True + + @staticmethod + def _check_cuda_compatibility(throw: bool = True): + toolkit_version = nb_cuda.runtime.get_version() + driver_version = nb_cuda.driver.driver.get_version() + + # numba_cuda should handle the compatibility between toolkit and driver versions + # otherwise check if the driver and runtime matches, or if the last working numba version + # matches the driver for CUDA 12 + try: + # try importing cuda.core as it can be used later to check the compatibility + # it is okay to fail as it may not be installed, the check later can handle this + import cuda.core + except ImportError: + pass + + # numba_cuda similarly to numba provides numba.cuda module so we need + # to check is package is present to learn who provides it + numba_cuda_missing = not importlib.util.find_spec("numba_cuda") + cuda_core_too_old = ( + importlib.util.find_spec("core") + and importlib.util.find_spec("cuda.core") + and Version(cuda.core.__version__) <= Version("0.3.1") + and nb_cuda.driver.driver.get_version()[0] > 12 + ) + toolkit_newer_than_driver = toolkit_version > driver_version + numba_too_old_for_driver = ( + Version(nb.__version__) <= Version("0.61.2") + and nb_cuda.driver.driver.get_version()[0] > 12 + ) + + if numba_cuda_missing or cuda_core_too_old: + if toolkit_newer_than_driver or numba_too_old_for_driver: + if throw: + raise RuntimeError( + f"Environment is not compatible with Numba GPU operator. " + f"Driver version is {driver_version} and CUDA Toolkit " + f"version is {toolkit_version}. " + "Driver cannot be older than the CUDA Toolkit" + ) + else: + return False + return True + + +# Register the main operator +ops._wrap_op(NumbaFunction, "fn", "nvidia.dali.plugin.numba") + +# Kept for backwards compatibility - deprecated alias +ops._wrap_op(NumbaFunction, "fn.experimental", "nvidia.dali.plugin.numba") + +# Import submodules from . import fn # noqa F401 diff --git a/dali/python/nvidia/dali/plugin/numba/__init__.pyi b/dali/python/nvidia/dali/plugin/numba/__init__.pyi new file mode 100644 index 00000000000..74681af44e4 --- /dev/null +++ b/dali/python/nvidia/dali/plugin/numba/__init__.pyi @@ -0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union, List, Sequence, Callable + +from nvidia.dali.data_node import DataNode + +from nvidia.dali.types import DALIDataType +from nvidia.dali._typing import TensorLikeIn + +from . import fn as fn + +class NumbaFunction: + """Invokes a njit compiled Numba function. + + The run function should be a Python function that can be compiled in Numba ``nopython`` mode.""" + + def __init__( + self, + run_fn: Optional[Callable[..., None]] = None, + out_types: Optional[List[DALIDataType]] = None, + in_types: Optional[List[DALIDataType]] = None, + outs_ndim: Optional[List[int]] = None, + ins_ndim: Optional[List[int]] = None, + setup_fn: Optional[ + Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]] + ] = None, + device: str = "cpu", + batch_processing: bool = False, + blocks: Optional[Sequence[int]] = None, + threads_per_block: Optional[Sequence[int]] = None, + bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], + seed: Optional[int] = -1, + ) -> None: ... + def __call__( + self, + __input_0: Union[DataNode, TensorLikeIn], + __input_1: Union[DataNode, TensorLikeIn, None] = None, + __input_2: Union[DataNode, TensorLikeIn, None] = None, + __input_3: Union[DataNode, TensorLikeIn, None] = None, + __input_4: Union[DataNode, TensorLikeIn, None] = None, + __input_5: Union[DataNode, TensorLikeIn, None] = None, + /, + *, + run_fn: Optional[Callable[..., None]] = None, + out_types: Optional[List[DALIDataType]] = None, + in_types: Optional[List[DALIDataType]] = None, + outs_ndim: Optional[List[int]] = None, + ins_ndim: Optional[List[int]] = None, + setup_fn: Optional[ + Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]] + ] = None, + device: str = "cpu", + batch_processing: bool = False, + blocks: Optional[Sequence[int]] = None, + threads_per_block: Optional[Sequence[int]] = None, + bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], + seed: Optional[int] = -1, + ) -> Union[DataNode, Sequence[DataNode]]: ... diff --git a/dali/python/nvidia/dali/plugin/numba/experimental/__init__.py b/dali/python/nvidia/dali/plugin/numba/experimental/__init__.py index e1785b9bd55..b9e4473a64f 100644 --- a/dali/python/nvidia/dali/plugin/numba/experimental/__init__.py +++ b/dali/python/nvidia/dali/plugin/numba/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,584 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from packaging.version import Version - -from nvidia.dali.pipeline import Pipeline -from nvidia.dali.data_node import DataNode as _DataNode -from nvidia.dali import ops -from nvidia.dali import types as dali_types -from numba import types as numba_types -from numba import njit, cfunc, carray -from numba import cuda as nb_cuda -import numpy as np -import numba as nb -import importlib - - -_to_numpy = { - dali_types.BOOL: "bool_", - dali_types.UINT8: "uint8", - dali_types.UINT16: "uint16", - dali_types.UINT32: "uint32", - dali_types.UINT64: "uint64", - dali_types.INT8: "int8", - dali_types.INT16: "int16", - dali_types.INT32: "int32", - dali_types.INT64: "int64", - dali_types.FLOAT16: "float16", - dali_types.FLOAT: "float32", - dali_types.FLOAT64: "float64", -} - -_to_numba = { - dali_types.BOOL: numba_types.boolean, - dali_types.UINT8: numba_types.uint8, - dali_types.UINT16: numba_types.uint16, - dali_types.UINT32: numba_types.uint32, - dali_types.UINT64: numba_types.uint64, - dali_types.INT8: numba_types.int8, - dali_types.INT16: numba_types.int16, - dali_types.INT32: numba_types.int32, - dali_types.INT64: numba_types.int64, - dali_types.FLOAT16: numba_types.float16, - dali_types.FLOAT: numba_types.float32, - dali_types.FLOAT64: numba_types.float64, -} - - -# Minimal version of Numba that is required for Numba GPU operator to work -minimal_numba_version = { - 11: Version("0.55.2"), - 12: Version("0.57.0"), -} - - -@nb.extending.intrinsic -def address_as_void_pointer(typingctx, src): - from numba.core import types, cgutils - - sig = types.voidptr(src) - - def codegen(cgctx, builder, sig, args): - return builder.inttoptr(args[0], cgutils.voidptr_t) - - return sig, codegen - - -@njit -def _get_shape_view(shapes_ptr, ndims_ptr, num_dims, num_samples): - ndims = carray(address_as_void_pointer(ndims_ptr), num_dims, dtype=np.int32) - samples = carray(address_as_void_pointer(shapes_ptr), (num_dims, num_samples), dtype=np.int64) - ret = [] - for sample, size in zip(samples, ndims): - d = [] - for shape_ptr in sample: - d.append(carray(address_as_void_pointer(shape_ptr), size, dtype=np.int64)) - ret.append(d) - return ret - - -class NumbaFunction( - ops.python_op_factory("NumbaFunctionBase", "NumbaFunction", "NumbaFuncImpl", generated=False) -): - _impl_module = "nvidia.dali.plugin.numba" - ops.register_cpu_op("NumbaFunction") - ops.register_gpu_op("NumbaFunction") - - @property - def spec(self): - return self._spec - - @property - def schema(self): - return self._schema - - @property - def device(self): - return self._device - - @property - def preserve(self): - return self._preserve - - def _setup_fn_sig(self): - return numba_types.void( - numba_types.uint64, - numba_types.uint64, - numba_types.int32, - numba_types.uint64, - numba_types.uint64, - numba_types.int32, - numba_types.int32, - ) - - def _run_fn_sig(self, batch_processing=False): - sig_types = [] - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.int32) - - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.uint64) - sig_types.append(numba_types.int32) - - if batch_processing: - sig_types.append(numba_types.int32) - return numba_types.void(*sig_types) - - def _get_carray_eval_lambda(self, dtype, ndim): - eval_string = "lambda ptr, shape: carray(ptr, (" - for i in range(ndim): - eval_string += "shape[{}]".format(i) - eval_string += ", " if i + 1 != ndim else "), " - eval_string += "dtype=np.{})".format(_to_numpy[dtype]) - return njit(eval(eval_string)) # nosec B307 - - def _get_carrays_eval_lambda(self, types, ndim): - ret = [self._get_carray_eval_lambda(dtype, ndim) for dtype, ndim in zip(types, ndim)] - ret += [njit(eval(("lambda x, y: None"))) for i in range(6 - len(types))] # nosec B307 - return tuple(ret) - - def _get_run_fn_lambda(self, num_outs, num_ins): - eval_string = ( - "lambda run_fn, out0, out1, out2, out3, out4, out5, " - "in0, in1, in2, in3, in4, in5 : " - "run_fn(" - ) - for i in range(num_outs): - eval_string += "out{}".format(i) - eval_string += ", " if i + 1 != num_outs else ", " - for i in range(num_ins): - eval_string += "in{}".format(i) - eval_string += ", " if i + 1 != num_ins else ")" - return njit(eval(eval_string)) # nosec B307 - - def _get_setup_fn_cpu(self, setup_fn): - setup_fn_address = None - if setup_fn is not None: - setup_fn = njit(setup_fn) - - @cfunc(self._setup_fn_sig(), nopython=True) - def setup_cfunc( - out_shapes_ptr, - out_ndims_ptr, - num_outs, - in_shapes_ptr, - in_ndims_ptr, - num_ins, - num_samples, - ): - out_shapes_np = _get_shape_view( - out_shapes_ptr, out_ndims_ptr, num_outs, num_samples - ) - in_shapes_np = _get_shape_view(in_shapes_ptr, in_ndims_ptr, num_outs, num_samples) - setup_fn(out_shapes_np, in_shapes_np) - - setup_fn_address = setup_cfunc.address - - return setup_fn_address - - def _get_run_fn_gpu(self, run_fn, types, dims): - nvvm_options = {"fastmath": False, "opt": 3} - - cuda_arguments = [] - for dali_type, ndim in zip(types, dims): - cuda_arguments.append(numba_types.Array(_to_numba[dali_type], ndim, "C")) - - if Version(nb.__version__) < Version("0.57.0"): - cres = nb_cuda.compiler.compile_cuda(run_fn, numba_types.void, cuda_arguments) - else: - pipeline = Pipeline.current() - device_id = pipeline.device_id - old_device = nb_cuda.api.get_current_device().id - cc = nb_cuda.api.select_device(device_id).compute_capability - nb_cuda.api.select_device(old_device) - cres = nb_cuda.compiler.compile_cuda( - run_fn, - numba_types.void, - cuda_arguments, - nvvm_options=nvvm_options, - fastmath=False, - cc=cc, - ) - - tgt_ctx = cres.target_context - code = run_fn.__code__ - filename = code.co_filename - linenum = code.co_firstlineno - return_value = 0 - if Version(nb.__version__) < Version("0.57.0"): - nvvm_options["debug"] = False - nvvm_options["lineinfo"] = False - lib, _ = tgt_ctx.prepare_cuda_kernel( - cres.library, cres.fndesc, True, nvvm_options, filename, linenum - ) - return_value = lib.get_cufunc().handle.value - else: - if hasattr(tgt_ctx, "prepare_cuda_kernel"): - lib, _ = tgt_ctx.prepare_cuda_kernel( - cres.library, cres.fndesc, False, True, nvvm_options, filename, linenum - ) - return_value = lib.get_cufunc().handle.value - else: - from numba.cuda.compiler import kernel_fixup - - lib = cres.library - kernel = lib.get_function(cres.fndesc.llvm_func_name) - kernel_fixup(kernel, debug=False) - lib._entry_name = cres.fndesc.llvm_func_name - return_value = int(lib.get_cufunc().handle) - - return return_value - - def _get_run_fn_cpu(self, run_fn, out_types, in_types, outs_ndim, ins_ndim, batch_processing): - ( - out0_lambda, - out1_lambda, - out2_lambda, - out3_lambda, - out4_lambda, - out5_lambda, - ) = self._get_carrays_eval_lambda(out_types, outs_ndim) - ( - in0_lambda, - in1_lambda, - in2_lambda, - in3_lambda, - in4_lambda, - in5_lambda, - ) = self._get_carrays_eval_lambda(in_types, ins_ndim) - run_fn = njit(run_fn) - run_fn_lambda = self._get_run_fn_lambda(len(out_types), len(in_types)) - if batch_processing: - - @cfunc(self._run_fn_sig(batch_processing=True), nopython=True) - def run_cfunc( - out_ptr, - out_shapes_ptr, - out_ndims_ptr, - num_outs, - in_ptr, - in_shapes_ptr, - in_ndims_ptr, - num_ins, - num_samples, - ): - out0 = out1 = out2 = out3 = out4 = out5 = None - out_shapes_np = _get_shape_view( - out_shapes_ptr, out_ndims_ptr, num_outs, num_samples - ) - out_arr = carray( - address_as_void_pointer(out_ptr), (num_outs, num_samples), dtype=np.int64 - ) - if num_outs >= 1: - out0 = [ - out0_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[0], out_shapes_np[0]) - ] - if num_outs >= 2: - out1 = [ - out1_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[1], out_shapes_np[1]) - ] - if num_outs >= 3: - out2 = [ - out2_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[2], out_shapes_np[2]) - ] - if num_outs >= 4: - out3 = [ - out3_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[3], out_shapes_np[3]) - ] - if num_outs >= 5: - out4 = [ - out4_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[4], out_shapes_np[4]) - ] - if num_outs >= 6: - out5 = [ - out5_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(out_arr[5], out_shapes_np[5]) - ] - - in0 = in1 = in2 = in3 = in4 = in5 = None - in_shapes_np = _get_shape_view(in_shapes_ptr, in_ndims_ptr, num_ins, num_samples) - in_arr = carray( - address_as_void_pointer(in_ptr), (num_ins, num_samples), dtype=np.int64 - ) - if num_ins >= 1: - in0 = [ - in0_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[0], in_shapes_np[0]) - ] - if num_ins >= 2: - in1 = [ - in1_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[1], in_shapes_np[1]) - ] - if num_ins >= 3: - in2 = [ - in2_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[2], in_shapes_np[2]) - ] - if num_ins >= 4: - in3 = [ - in3_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[3], in_shapes_np[3]) - ] - if num_ins >= 5: - in4 = [ - in4_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[4], in_shapes_np[4]) - ] - if num_ins >= 6: - in5 = [ - in5_lambda(address_as_void_pointer(ptr), shape) - for ptr, shape in zip(in_arr[5], in_shapes_np[5]) - ] - - run_fn_lambda( - run_fn, out0, out1, out2, out3, out4, out5, in0, in1, in2, in3, in4, in5 - ) - - else: - - @cfunc(self._run_fn_sig(batch_processing=False), nopython=True) - def run_cfunc( - out_ptr, - out_shapes_ptr, - out_ndims_ptr, - num_outs, - in_ptr, - in_shapes_ptr, - in_ndims_ptr, - num_ins, - ): - out0 = out1 = out2 = out3 = out4 = out5 = None - out_shapes_np = _get_shape_view(out_shapes_ptr, out_ndims_ptr, num_outs, 1) - out_arr = carray(address_as_void_pointer(out_ptr), num_outs, dtype=np.int64) - if num_outs >= 1: - out0 = out0_lambda(address_as_void_pointer(out_arr[0]), out_shapes_np[0][0]) - if num_outs >= 2: - out1 = out1_lambda(address_as_void_pointer(out_arr[1]), out_shapes_np[1][0]) - if num_outs >= 3: - out2 = out2_lambda(address_as_void_pointer(out_arr[2]), out_shapes_np[2][0]) - if num_outs >= 4: - out3 = out3_lambda(address_as_void_pointer(out_arr[3]), out_shapes_np[3][0]) - if num_outs >= 5: - out4 = out4_lambda(address_as_void_pointer(out_arr[4]), out_shapes_np[4][0]) - if num_outs >= 6: - out5 = out5_lambda(address_as_void_pointer(out_arr[5]), out_shapes_np[5][0]) - - in0 = in1 = in2 = in3 = in4 = in5 = None - in_shapes_np = _get_shape_view(in_shapes_ptr, in_ndims_ptr, num_ins, 1) - in_arr = carray(address_as_void_pointer(in_ptr), num_ins, dtype=np.int64) - if num_ins >= 1: - in0 = in0_lambda(address_as_void_pointer(in_arr[0]), in_shapes_np[0][0]) - if num_ins >= 2: - in1 = in1_lambda(address_as_void_pointer(in_arr[1]), in_shapes_np[1][0]) - if num_ins >= 3: - in2 = in2_lambda(address_as_void_pointer(in_arr[2]), in_shapes_np[2][0]) - if num_ins >= 4: - in3 = in3_lambda(address_as_void_pointer(in_arr[3]), in_shapes_np[3][0]) - if num_ins >= 5: - in4 = in4_lambda(address_as_void_pointer(in_arr[4]), in_shapes_np[4][0]) - if num_ins >= 6: - in5 = in5_lambda(address_as_void_pointer(in_arr[5]), in_shapes_np[5][0]) - - run_fn_lambda( - run_fn, out0, out1, out2, out3, out4, out5, in0, in1, in2, in3, in4, in5 - ) - - return run_cfunc.address - - def __call__(self, *inputs, **kwargs): - pipeline = Pipeline.current() - inputs = ops._preprocess_inputs(inputs, self.__class__.__name__, self._device, None) - if pipeline is None: - Pipeline._raise_pipeline_required(self.__class__.__name__) - for inp in inputs: - if not isinstance(inp, _DataNode): - raise TypeError( - ( - "Expected inputs of type `DataNode`. Received input of type '{}'. " - + "Python Operators do not support Multiple Input Sets." - ).format(type(inp).__name__) - ) - - kwargs.update( - { - "run_fn": self.run_fn, - "out_types": self.out_types, - "in_types": self.in_types, - "outs_ndim": self.outs_ndim, - "ins_ndim": self.ins_ndim, - "batch_processing": self.batch_processing, - } - ) - if self.setup_fn is not None: - kwargs.update({"setup_fn": self.setup_fn}) - if self.device == "gpu": - kwargs.update( - { - "blocks": self.blocks, - "threads_per_block": self.threads_per_block, - } - ) - - return super().__call__(*inputs, **kwargs) - - def __init__( - self, - run_fn, - out_types, - in_types, - outs_ndim, - ins_ndim, - setup_fn=None, - device="cpu", - batch_processing=False, - blocks=None, - threads_per_block=None, - **kwargs, - ): - if device == "gpu": - NumbaFunction._check_minimal_numba_version() - NumbaFunction._check_cuda_compatibility() - - # TODO(klecki): Normalize the types into lists first, than apply the checks - assert len(in_types) == len(ins_ndim), ( - "Number of input types " "and input dimensions should match." - ) - assert len(out_types) == len(outs_ndim), ( - "Number of output types " "and output dimensions should match." - ) - - if "float16" in dir(numba_types): - for t in [*in_types, *out_types]: - if t == dali_types.FLOAT16: - raise RuntimeError( - "Numba does not support float16 for " - "current Python version. " - "Python 3.7 or newer is required" - ) - - if device == "gpu": - assert batch_processing is False, "Currently batch processing for GPU is not supported." - assert len(blocks) == 3, ( - "`blocks` array should contain 3 numbers, while received: " f"{len(blocks)}" - ) - for i, block_dim in enumerate(blocks): - assert block_dim > 0, ( - "All dimensions should be positive. Value specified in " - f"`blocks` at index {i} is nonpositive: {block_dim}" - ) - - assert len(threads_per_block) == 3, ( - "`threads_per_block` array should contain 3 " - f"numbers, while received: {len(threads_per_block)}" - ) - for i, threads in enumerate(threads_per_block): - assert threads > 0, ( - "All dimensions should be positive. " - "Value specified in `threads_per_block` at index " - f"{i} is nonpositive: {threads}" - ) - - if not isinstance(outs_ndim, list): - outs_ndim = [outs_ndim] - if not isinstance(ins_ndim, list): - ins_ndim = [ins_ndim] - if not isinstance(out_types, list): - out_types = [out_types] - if not isinstance(in_types, list): - in_types = [in_types] - - super().__init__(device=device, **kwargs) - - if device == "gpu": - self.run_fn = self._get_run_fn_gpu(run_fn, out_types + in_types, outs_ndim + ins_ndim) - else: - self.run_fn = self._get_run_fn_cpu( - run_fn, out_types, in_types, outs_ndim, ins_ndim, batch_processing - ) - self.setup_fn = self._get_setup_fn_cpu(setup_fn) - self.out_types = out_types - self.in_types = in_types - self.outs_ndim = outs_ndim - self.ins_ndim = ins_ndim - self.num_outputs = len(out_types) - self.batch_processing = batch_processing - self._preserve = True - self.blocks = blocks - self.threads_per_block = threads_per_block - - @staticmethod - def _check_minimal_numba_version(throw: bool = True): - current_version = Version(nb.__version__) - toolkit_version = nb_cuda.runtime.get_version() - if toolkit_version[0] not in minimal_numba_version: - if throw: - raise RuntimeError(f"Unsupported CUDA toolkit version: {toolkit_version}") - else: - return False - min_ver = minimal_numba_version[toolkit_version[0]] - if current_version < min_ver: - if throw: - raise RuntimeError( - f"Insufficient Numba version. Numba GPU operator " - f"requires Numba {str(min_ver)} or higher. " - f"Detected version: {str(Version(nb.__version__))}." - ) - else: - return False - return True - - @staticmethod - def _check_cuda_compatibility(throw: bool = True): - toolkit_version = nb_cuda.runtime.get_version() - driver_version = nb_cuda.driver.driver.get_version() - - # numba_cuda should handle the compatibility between toolkit and driver versions - # otherwise check if the driver and runtime matches, or if the last working numba version - # matches the driver for CUDA 12 - try: - # try importing cuda.core as it can be used later to check the compatibility - # it is okay to fail as it may not be installed, the check later can handle this - import cuda.core - except ImportError: - pass - - # numba_cuda similarly to numba provides numba.cuda module so we need - # to check is package is present to learn who provides it - numba_cuda_missing = not importlib.util.find_spec("numba_cuda") - cuda_core_too_old = ( - importlib.util.find_spec("core") - and importlib.util.find_spec("cuda.core") - and Version(cuda.core.__version__) <= Version("0.3.1") - and nb_cuda.driver.driver.get_version()[0] > 12 - ) - toolkit_newer_than_driver = toolkit_version > driver_version - numba_too_old_for_driver = ( - Version(nb.__version__) <= Version("0.61.2") - and nb_cuda.driver.driver.get_version()[0] > 12 - ) - - if numba_cuda_missing or cuda_core_too_old: - if toolkit_newer_than_driver or numba_too_old_for_driver: - if throw: - raise RuntimeError( - f"Environment is not compatible with Numba GPU operator. " - f"Driver version is {driver_version} and CUDA Toolkit " - f"version is {toolkit_version}. " - "Driver cannot be older than the CUDA Toolkit" - ) - else: - return False - return True - - -ops._wrap_op(NumbaFunction, "fn.experimental", "nvidia.dali.plugin.numba") +# Backwards compatibility: NumbaFunction is now in the parent module +# This module is kept for backwards compatibility +from nvidia.dali.plugin.numba import NumbaFunction # noqa F401 diff --git a/dali/python/nvidia/dali/plugin/numba/experimental/__init__.pyi b/dali/python/nvidia/dali/plugin/numba/experimental/__init__.pyi index 869d349d3a8..1d0da2b7092 100644 --- a/dali/python/nvidia/dali/plugin/numba/experimental/__init__.pyi +++ b/dali/python/nvidia/dali/plugin/numba/experimental/__init__.pyi @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,57 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, List, Sequence, Callable - -from nvidia.dali.data_node import DataNode - -from nvidia.dali.types import DALIDataType -from nvidia.dali._typing import TensorLikeIn - -class NumbaFunction: - """Invokes a njit compiled Numba function. - - The run function should be a Python function that can be compiled in Numba ``nopython`` mode.""" - - def __init__( - self, - run_fn: Optional[Callable[..., None]] = None, - out_types: Optional[List[DALIDataType]] = None, - in_types: Optional[List[DALIDataType]] = None, - outs_ndim: Optional[List[int]] = None, - ins_ndim: Optional[List[int]] = None, - setup_fn: Optional[ - Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]] - ] = None, - device: str = "cpu", - batch_processing: bool = False, - blocks: Optional[Sequence[int]] = None, - threads_per_block: Optional[Sequence[int]] = None, - bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], - seed: Optional[int] = -1, - ) -> None: ... - def __call__( - self, - __input_0: Union[DataNode, TensorLikeIn], - __input_1: Union[DataNode, TensorLikeIn, None] = None, - __input_2: Union[DataNode, TensorLikeIn, None] = None, - __input_3: Union[DataNode, TensorLikeIn, None] = None, - __input_4: Union[DataNode, TensorLikeIn, None] = None, - __input_5: Union[DataNode, TensorLikeIn, None] = None, - /, - *, - run_fn: Optional[Callable[..., None]] = None, - out_types: Optional[List[DALIDataType]] = None, - in_types: Optional[List[DALIDataType]] = None, - outs_ndim: Optional[List[int]] = None, - ins_ndim: Optional[List[int]] = None, - setup_fn: Optional[ - Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]] - ] = None, - device: str = "cpu", - batch_processing: bool = False, - blocks: Optional[Sequence[int]] = None, - threads_per_block: Optional[Sequence[int]] = None, - bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], - seed: Optional[int] = -1, - ) -> Union[DataNode, Sequence[DataNode]]: ... +# Backwards compatibility: NumbaFunction is now in the parent module +from nvidia.dali.plugin.numba import NumbaFunction as NumbaFunction diff --git a/dali/python/nvidia/dali/plugin/numba/fn/__init__.py b/dali/python/nvidia/dali/plugin/numba/fn/__init__.py index c11f17cdcd5..02db62baa27 100644 --- a/dali/python/nvidia/dali/plugin/numba/fn/__init__.py +++ b/dali/python/nvidia/dali/plugin/numba/fn/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from . import experimental # noqa F401 diff --git a/dali/python/nvidia/dali/plugin/numba/fn/__init__.pyi b/dali/python/nvidia/dali/plugin/numba/fn/__init__.pyi new file mode 100644 index 00000000000..0ccccf912ed --- /dev/null +++ b/dali/python/nvidia/dali/plugin/numba/fn/__init__.pyi @@ -0,0 +1,51 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union, List, Sequence, Callable + +from nvidia.dali.data_node import DataNode + +from nvidia.dali.types import DALIDataType +from nvidia.dali._typing import TensorLikeIn + +from . import experimental as experimental + +def numba_function( + __input_0: Union[DataNode, TensorLikeIn], + __input_1: Union[DataNode, TensorLikeIn, None] = None, + __input_2: Union[DataNode, TensorLikeIn, None] = None, + __input_3: Union[DataNode, TensorLikeIn, None] = None, + __input_4: Union[DataNode, TensorLikeIn, None] = None, + __input_5: Union[DataNode, TensorLikeIn, None] = None, + /, + *, + run_fn: Callable[..., None], + out_types: List[DALIDataType], + in_types: List[DALIDataType], + outs_ndim: List[int], + ins_ndim: List[int], + setup_fn: Optional[Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]]] = None, + batch_processing: bool = False, + blocks: Optional[Sequence[int]] = None, + threads_per_block: Optional[Sequence[int]] = None, + bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], + preserve: Optional[bool] = False, + seed: Optional[int] = -1, + device: Optional[str] = None, + name: Optional[str] = None, +) -> Union[DataNode, Sequence[DataNode]]: + """Invokes a njit compiled Numba function. + + The run function should be a Python function that can be compiled in Numba ``nopython`` mode.""" + ... diff --git a/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.py b/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.py index dbfe137c14d..3145d704e7c 100644 --- a/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.py +++ b/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# Backwards compatibility: numba_function is now in the parent fn module +from nvidia.dali.plugin.numba.fn import numba_function as numba_function + +__all__ = ["numba_function"] diff --git a/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.pyi b/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.pyi index bb4192ba81f..3145d704e7c 100644 --- a/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.pyi +++ b/dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.pyi @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,38 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, List, Sequence, Callable +# Backwards compatibility: numba_function is now in the parent fn module +from nvidia.dali.plugin.numba.fn import numba_function as numba_function -from nvidia.dali.data_node import DataNode - -from nvidia.dali.types import DALIDataType -from nvidia.dali._typing import TensorLikeIn - -def numba_function( - __input_0: Union[DataNode, TensorLikeIn], - __input_1: Union[DataNode, TensorLikeIn, None] = None, - __input_2: Union[DataNode, TensorLikeIn, None] = None, - __input_3: Union[DataNode, TensorLikeIn, None] = None, - __input_4: Union[DataNode, TensorLikeIn, None] = None, - __input_5: Union[DataNode, TensorLikeIn, None] = None, - /, - *, - run_fn: Callable[..., None], - out_types: List[DALIDataType], - in_types: List[DALIDataType], - outs_ndim: List[int], - ins_ndim: List[int], - setup_fn: Optional[Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]], None]]] = None, - batch_processing: bool = False, - blocks: Optional[Sequence[int]] = None, - threads_per_block: Optional[Sequence[int]] = None, - bytes_per_sample_hint: Union[Sequence[int], int, None] = [0], - preserve: Optional[bool] = False, - seed: Optional[int] = -1, - device: Optional[str] = None, - name: Optional[str] = None, -) -> Union[DataNode, Sequence[DataNode]]: - """Invokes a njit compiled Numba function. - - The run function should be a Python function that can be compiled in Numba ``nopython`` mode.""" - ... +__all__ = ["numba_function"] diff --git a/dali/test/python/checkpointing/test_dali_checkpointing.py b/dali/test/python/checkpointing/test_dali_checkpointing.py index 0e46f307e35..5d092a16d30 100644 --- a/dali/test/python/checkpointing/test_dali_checkpointing.py +++ b/dali/test/python/checkpointing/test_dali_checkpointing.py @@ -1240,7 +1240,9 @@ def pipe(arg): "hidden.*", "experimental.hidden.*", "clahe", + "decoders.video", "experimental.decoders.video", + "experimental.decoders.hidden.video", "experimental.inputs.video", "plugin.video.decoder", ] diff --git a/dali/test/python/checkpointing/test_dali_stateless_operators.py b/dali/test/python/checkpointing/test_dali_stateless_operators.py index 80e84d59fbf..d27b8b3c85e 100644 --- a/dali/test/python/checkpointing/test_dali_stateless_operators.py +++ b/dali/test/python/checkpointing/test_dali_stateless_operators.py @@ -214,9 +214,9 @@ def test_resize_stateless(device): @params("cpu", "gpu") -@stateless_signed_off("experimental.tensor_resize") +@stateless_signed_off("experimental.tensor_resize", "tensor_resize") def test_tensor_resize_stateless(device): - check_single_input(fn.experimental.tensor_resize, device, axes=[0, 1], sizes=[40, 40]) + check_single_input(fn.tensor_resize, device, axes=[0, 1], sizes=[40, 40]) @params("cpu", "gpu") @@ -335,9 +335,9 @@ def test_reductions_variance_stateless(device): @params("cpu", "gpu") -@stateless_signed_off("experimental.equalize") +@stateless_signed_off("experimental.equalize", "equalize") def test_equalize_stateless(device): - check_single_input(fn.experimental.equalize, device) + check_single_input(fn.equalize, device) @stateless_signed_off("transforms.crop") @@ -469,10 +469,10 @@ def test_sphere_stateless(device): @params("cpu", "gpu") -@stateless_signed_off("experimental.filter") +@stateless_signed_off("experimental.filter", "filter") def test_filter_stateless(device): check_single_input( - lambda x, **kwargs: fn.experimental.filter(x, np.full((3, 3), 1 / 9), **kwargs), + lambda x, **kwargs: fn.filter(x, np.full((3, 3), 1 / 9), **kwargs), device, ) @@ -495,14 +495,14 @@ def pipeline_factory(): @params("cpu", "gpu") -@stateless_signed_off("experimental.debayer") +@stateless_signed_off("experimental.debayer", "debayer") def test_debayer_stateless(device): @pipeline_def(enable_checkpointing=True) def pipeline_factory(): data = fn.external_source(source=RandomBatch((40, 40)), layout="HW", batch=True) if device == "gpu": data = data.gpu() - return fn.experimental.debayer(data, blue_position=[0, 0]) + return fn.debayer(data, blue_position=[0, 0]) check_is_pipeline_stateless(pipeline_factory) @@ -772,7 +772,7 @@ def wrapper(x, **kwargs): @attr("numba") -@stateless_signed_off("experimental.numba_function") +@stateless_signed_off("experimental.numba_function", "numba_function") def test_numba_function_stateless(): import nvidia.dali.plugin.numba as dali_numba @@ -786,7 +786,7 @@ def numba_pipe(): forty_two = fn.external_source( source=lambda x: np.full((2,), 42, dtype=np.uint8), batch=False ) - out = dali_numba.fn.experimental.numba_function( + out = dali_numba.fn.numba_function( forty_two, run_fn=double_sample, out_types=[types.DALIDataType.UINT8], diff --git a/dali/test/python/decoder/test_video.py b/dali/test/python/decoder/test_video.py index 7bde338c9bf..4fbda2068d9 100644 --- a/dali/test/python/decoder/test_video.py +++ b/dali/test/python/decoder/test_video.py @@ -215,7 +215,7 @@ def ref_iter(epochs=1, device="cpu"): yield np.array(output[0]) -@params(("mixed", fn.experimental)) +@params(("mixed", fn.experimental), ("mixed", fn)) def test_video_decoder(device, module): skip_if_m60() batch_size = 3 diff --git a/dali/test/python/operator_1/test_debayer.py b/dali/test/python/operator_1/test_debayer.py index cd7ac08c2dd..6f880f50e73 100644 --- a/dali/test/python/operator_1/test_debayer.py +++ b/dali/test/python/operator_1/test_debayer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -126,9 +126,15 @@ def get_test_data(cls, dtype): return cls.bayered_imgs, cls.npp_baseline return cls.bayered_imgs16t, cls.npp_baseline16t - @params(*enumerate(itertools.product([1, 64], bayer_patterns, ("gpu", "cpu")))) + @params( + *enumerate( + itertools.product( + [1, 64], bayer_patterns, ("gpu", "cpu"), (fn.experimental.debayer, fn.debayer) + ) + ) + ) def test_debayer_fixed_pattern(self, i, args): - (batch_size, pattern, device) = args + (batch_size, pattern, device, debayer_op) = args num_iterations = 3 test_hwc_single_channel_input = i % 2 == 1 bayered_imgs, npp_baseline = self.get_test_data(np.uint8) @@ -147,9 +153,7 @@ def debayer_pipeline(): bayer_imgs, idxs = fn.external_source(source=source, batch=False, num_outputs=2) if device == "gpu": bayer_imgs = bayer_imgs.gpu() - debayered_imgs = fn.experimental.debayer( - bayer_imgs, blue_position=blue_position(pattern) - ) + debayered_imgs = debayer_op(bayer_imgs, blue_position=blue_position(pattern)) return debayered_imgs, idxs pipe = debayer_pipeline(batch_size=batch_size, device_id=0, num_threads=4) @@ -176,8 +180,10 @@ def debayer_pipeline(): else: assert compare_image_equality(baseline, img_debayered) - @cartesian_params((1, 11, 184), (np.uint8, np.uint16), ("gpu", "cpu")) - def test_debayer_per_sample_pattern(self, batch_size, dtype, device): + @cartesian_params( + (1, 11, 184), (np.uint8, np.uint16), ("gpu", "cpu"), (fn.experimental.debayer, fn.debayer) + ) + def test_debayer_per_sample_pattern(self, batch_size, dtype, device, debayer_op): num_iterations = 3 num_patterns = len(bayer_patterns) rng = np.random.default_rng(seed=42 + batch_size) @@ -200,7 +206,7 @@ def debayer_pipeline(): ) if device == "gpu": bayer_imgs = bayer_imgs.gpu() - debayered_imgs = fn.experimental.debayer(bayer_imgs, blue_position=blue_poses) + debayered_imgs = debayer_op(bayer_imgs, blue_position=blue_poses) return debayered_imgs, blue_poses, idxs pipe = debayer_pipeline(batch_size=batch_size, device_id=0, num_threads=4) @@ -229,9 +235,11 @@ def debayer_pipeline(): assert compare_image_equality(img_debayered, baseline) @cartesian_params( - ("bilinear_ocv", "edgeaware_ocv", "vng_ocv", "gray_ocv"), (np.uint8, np.uint16) + ("bilinear_ocv", "edgeaware_ocv", "vng_ocv", "gray_ocv"), + (np.uint8, np.uint16), + (fn.experimental.debayer, fn.debayer), ) - def test_cpu_algorithms(self, algorithm: str, dtype: np.dtype): + def test_cpu_algorithms(self, algorithm: str, dtype: np.dtype, debayer_op): if algorithm == "vng_ocv" and dtype == np.uint16: # VNG algorithm is not supported for uint16 return @@ -252,7 +260,7 @@ def source(sample_info): @pipeline_def def debayer_pipeline(): bayer_imgs, idxs = fn.external_source(source=source, batch=False, num_outputs=2) - debayered_imgs = fn.experimental.debayer( + debayered_imgs = debayer_op( bayer_imgs, blue_position=blue_position(pattern), algorithm=algorithm ) return debayered_imgs, idxs @@ -315,9 +323,7 @@ def debayer_pipeline(): ) if device == "gpu": bayered_vid = bayered_vid.gpu() - debayered_vid = fn.experimental.debayer( - bayered_vid, blue_position=fn.per_frame(blue_positions) - ) + debayered_vid = fn.debayer(bayered_vid, blue_position=fn.per_frame(blue_positions)) return debayered_vid, idxs pipe = debayer_pipeline(batch_size=batch_size, device_id=0, num_threads=4) @@ -353,7 +359,7 @@ def _test_shape_pipeline(shape, dtype): @pipeline_def def pipeline(): bayer_imgs = fn.external_source(source_full_array(shape, dtype), batch=False) - return fn.experimental.debayer(bayer_imgs, blue_position=[0, 0], algorithm="bilinear_ocv") + return fn.debayer(bayer_imgs, blue_position=[0, 0], algorithm="bilinear_ocv") pipe = pipeline(batch_size=8, num_threads=4, device_id=0) pipe.run() @@ -391,7 +397,7 @@ def test_no_blue_position_specified(): @pipeline_def def pipeline(): bayer_imgs = fn.external_source(source_full_array((20, 20), np.uint8), batch=False) - return fn.experimental.debayer(bayer_imgs) + return fn.debayer(bayer_imgs) pipe = pipeline(batch_size=8, num_threads=4, device_id=0) pipe.run() @@ -404,7 +410,7 @@ def test_blue_position_outside_of_2x2_tile(blue_position_): @pipeline_def def pipeline(): bayer_imgs = fn.external_source(source_full_array((20, 20), np.uint8), batch=False) - return fn.experimental.debayer(bayer_imgs, blue_position=blue_position_) + return fn.debayer(bayer_imgs, blue_position=blue_position_) pipe = pipeline(batch_size=8, num_threads=4, device_id=0) pipe.run() @@ -419,9 +425,7 @@ def test_gpu_algorithm_unsupported(algorithm): @pipeline_def def pipeline(): bayer_imgs = fn.external_source(source_full_array((20, 20), np.uint8), batch=False) - return fn.experimental.debayer( - bayer_imgs.gpu(), blue_position=[0, 0], algorithm=algorithm - ) + return fn.debayer(bayer_imgs.gpu(), blue_position=[0, 0], algorithm=algorithm) pipe = pipeline(batch_size=8, num_threads=4, device_id=0) pipe.run() @@ -434,7 +438,7 @@ def test_cpu_algorithm_unsupported(algorithm): @pipeline_def def pipeline(): bayer_imgs = fn.external_source(source_full_array((20, 20), np.uint8), batch=False) - return fn.experimental.debayer(bayer_imgs, blue_position=[0, 0], algorithm=algorithm) + return fn.debayer(bayer_imgs, blue_position=[0, 0], algorithm=algorithm) pipe = pipeline(batch_size=8, num_threads=4, device_id=0) pipe.run() diff --git a/dali/test/python/operator_2/test_equalize.py b/dali/test/python/operator_2/test_equalize.py index 8c9524bd73f..56f2e726684 100644 --- a/dali/test/python/operator_2/test_equalize.py +++ b/dali/test/python/operator_2/test_equalize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ def equalize_cv_baseline(img, layout): @pipeline_def -def images_pipeline(layout, dev): +def images_pipeline(layout, dev, equalize_op): images, _ = fn.readers.file( name="Reader", file_root=images_dir, prefetch_queue_depth=2, random_shuffle=True, seed=42 ) @@ -53,7 +53,7 @@ def images_pipeline(layout, dev): images = fn.decoders.image(images, device=decoder, output_type=types.RGB) if layout == "CHW": images = fn.transpose(images, perm=[2, 0, 1]) - equalized = fn.experimental.equalize(images) + equalized = equalize_op(images) return equalized, images @@ -62,15 +62,21 @@ def images_pipeline(layout, dev): itertools.product( ("cpu", "gpu"), (("HWC", 1), ("HWC", 32), ("CHW", 1), ("CHW", 7), ("HW", 253), ("HW", 128)), + (fn.experimental.equalize, fn.equalize), ) ) ) -def test_image_pipeline(dev, layout_batch_size): +def test_image_pipeline(dev, layout_batch_size, equalize_op): layout, batch_size = layout_batch_size num_iters = 2 pipe = images_pipeline( - num_threads=4, device_id=0, batch_size=batch_size, layout=layout, dev=dev + num_threads=4, + device_id=0, + batch_size=batch_size, + layout=layout, + dev=dev, + equalize_op=equalize_op, ) for _ in range(num_iters): @@ -107,7 +113,7 @@ def pipeline(): input = fn.external_source(input_sample, batch=False) if dev == "gpu": input = input.gpu() - return fn.experimental.equalize(input), input + return fn.equalize(input), input pipe = pipeline() diff --git a/dali/test/python/operator_2/test_filter.py b/dali/test/python/operator_2/test_filter.py index 09365290d23..82f2c11236c 100644 --- a/dali/test/python/operator_2/test_filter.py +++ b/dali/test/python/operator_2/test_filter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ def source(sample_info): @pipeline_def -def images_pipeline(dev, shapes, border, in_dtype, mode): +def images_pipeline(dev, shapes, border, in_dtype, mode, filter_op=fn.filter): images, _ = fn.readers.file( name="Reader", file_root=images_dir, prefetch_queue_depth=2, random_shuffle=True, seed=42 ) @@ -81,18 +81,18 @@ def images_pipeline(dev, shapes, border, in_dtype, mode): fill_val_limit = 1 if not np.issubdtype(in_dtype, np.integer) else np.iinfo(in_dtype).max fill_values = fn.random.uniform(range=[0, fill_val_limit], dtype=np_type_to_dali(in_dtype)) if border == "constant": - convolved = fn.experimental.filter( + convolved = filter_op( images, filters, fill_values, anchor=anchors, border=border, mode=mode ) else: - convolved = fn.experimental.filter( - images, filters, anchor=anchors, border=border, mode=mode - ) + convolved = filter_op(images, filters, anchor=anchors, border=border, mode=mode) return convolved, images, filters, anchors, fill_values @pipeline_def -def sample_pipeline(sample_shapes, sample_layout, filter_shapes, border, in_dtype, mode, dev): +def sample_pipeline( + sample_shapes, sample_layout, filter_shapes, border, in_dtype, mode, dev, filter_op=fn.filter +): samples = fn.external_source( source=create_sample_source(sample_shapes, in_dtype), batch=False, layout=sample_layout ) @@ -107,13 +107,11 @@ def sample_pipeline(sample_shapes, sample_layout, filter_shapes, border, in_dtyp fill_values = fn.cast_like(fill_values, samples) in_samples = samples.gpu() if dev == "gpu" else samples if border == "constant": - convolved = fn.experimental.filter( + convolved = filter_op( in_samples, filters, fill_values, anchor=anchors, border=border, mode=mode ) else: - convolved = fn.experimental.filter( - in_samples, filters, anchor=anchors, border=border, mode=mode - ) + convolved = filter_op(in_samples, filters, anchor=anchors, border=border, mode=mode) return convolved, samples, filters, anchors, fill_values @@ -201,6 +199,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "101", "same", + fn.filter, ), ( np.int8, @@ -210,6 +209,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 4, "101", "same", + fn.experimental.filter, ), ( np.uint8, @@ -219,6 +219,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "1001", "same", + fn.filter, ), ( np.float16, @@ -228,6 +229,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "wrap", "same", + fn.experimental.filter, ), ( np.uint8, @@ -237,6 +239,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "wrap", "valid", + fn.filter, ), ( np.uint16, @@ -246,6 +249,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "clamp", "same", + fn.experimental.filter, ), ( np.int8, @@ -255,6 +259,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 8, "wrap", "same", + fn.filter, ), ( np.float32, @@ -264,6 +269,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 4, "101", "same", + fn.experimental.filter, ), ], ) @@ -280,6 +286,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 4, "101", "same", + fn.filter, ), ( "gpu", @@ -290,6 +297,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 2, "1001", "same", + fn.experimental.filter, ), ( "gpu", @@ -300,6 +308,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 4, "constant", "same", + fn.filter, ), ( "gpu", @@ -310,6 +319,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 3, "101", "same", + fn.experimental.filter, ), ( "gpu", @@ -320,6 +330,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): 3, "101", "same", + fn.filter, ), ) @@ -328,7 +339,7 @@ def test_image_pipeline(dev, dtype, batch_size, border, mode): @attr("slow") @params(*(sample_2d_cases + sample_3d_cases)) def slow_test_samples( - dev, dtype, sample_layout, sample_shapes, filter_shapes, batch_size, border, mode + dev, dtype, sample_layout, sample_shapes, filter_shapes, batch_size, border, mode, filter_op ): num_iters = 2 @@ -343,6 +354,7 @@ def slow_test_samples( in_dtype=dtype, mode=mode, dev=dev, + filter_op=filter_op, ) if dtype == np.float32: atol = 1e-5 diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py index 787a2e59bc4..ee632e2d72f 100644 --- a/dali/test/python/test_dali_cpu_only.py +++ b/dali/test/python/test_dali_cpu_only.py @@ -261,7 +261,7 @@ def test_resize_cpu(): def test_tensor_resize_cpu(): - check_single_input(fn.experimental.tensor_resize, sizes=[50, 50], axes=[0, 1]) + check_single_input(fn.tensor_resize, sizes=[50, 50], axes=[0, 1]) def test_per_frame_cpu(): @@ -1451,7 +1451,7 @@ def test_io_file_read_cpu(): def test_debayer(): check_single_input( - fn.experimental.debayer, + fn.debayer, get_data=lambda: np.full((256, 256), 128, dtype=np.uint8), batch=False, input_layout="HW", @@ -1480,6 +1480,7 @@ def test_warp_perspective(): "decoders.image_random_crop", "decoders.numpy", "experimental.debayer", + "debayer", "experimental.decoders.image", "experimental.decoders.image_crop", "experimental.decoders.image_slice", @@ -1551,6 +1552,7 @@ def test_warp_perspective(): "cast_like", "resize", "experimental.tensor_resize", + "tensor_resize", "gaussian_blur", "laplacian", "crop_mirror_normalize", @@ -1644,7 +1646,9 @@ def test_warp_perspective(): "dl_tensor_python_function", "experimental.warp_perspective", "audio_resample", + "experimental.decoders.hidden.video", "experimental.decoders.video", + "decoders.video", "zeros", "zeros_like", "ones", @@ -1667,7 +1671,9 @@ def test_warp_perspective(): "optical_flow", # not supported for CPU "experimental.audio_resample", # Alias of audio_resample (already tested) "experimental.equalize", # not supported for CPU + "equalize", # not supported for CPU "experimental.filter", # not supported for CPU + "filter", # not supported for CPU "decoders.inflate", # not supported for CPU "experimental.inflate", # not supported for CPU "experimental.remap", # operator is GPU-only diff --git a/dali/test/python/test_dali_variable_batch_size.py b/dali/test/python/test_dali_variable_batch_size.py index 835b4f4fc21..f8cfbb087ee 100644 --- a/dali/test/python/test_dali_variable_batch_size.py +++ b/dali/test/python/test_dali_variable_batch_size.py @@ -22,6 +22,7 @@ import re from functools import partial from nose_utils import SkipTest, attr, nottest +from nose2.tools import params from nvidia.dali.pipeline import Pipeline, pipeline_def from nvidia.dali.pipeline.experimental import pipeline_def as experimental_pipeline_def from nvidia.dali.types import DALIDataType @@ -347,6 +348,7 @@ def numba_setup_out_shape(out_shape, in_shape): (fn.coord_transform, {"M": 0.5}), (fn.crop, {"crop": (5, 5)}), (fn.experimental.equalize, {"devices": ["gpu"]}), + (fn.equalize, {"devices": ["gpu"]}), ( fn.erase, { @@ -375,6 +377,7 @@ def numba_setup_out_shape(out_shape, in_shape): (fn.resize, {"resize_x": 50, "resize_y": 50}), (fn.resize_crop_mirror, {"crop": [5, 5], "resize_shorter": 10, "devices": ["cpu"]}), (fn.experimental.tensor_resize, {"sizes": [50, 50], "axes": [0, 1]}), + (fn.tensor_resize, {"sizes": [50, 50], "axes": [0, 1]}), (fn.rotate, {"angle": 25}), (fn.transpose, {"perm": [2, 0, 1]}), (fn.warp_affine, {"matrix": (0.1, 0.9, 10, 0.8, -0.2, -20)}), @@ -405,25 +408,27 @@ def numba_setup_out_shape(out_shape, in_shape): numba_compatible_devices.append("cpu") if len(numba_compatible_devices) > 0 and not os.environ.get("DALI_ENABLE_SANITIZERS", None): - from nvidia.dali.plugin.numba.fn.experimental import numba_function - - ops_image_custom_args.append( - ( - numba_function, - { - "batch_processing": False, - "devices": numba_compatible_devices, - "in_types": [types.UINT8], - "ins_ndim": [3], - "out_types": [types.UINT8], - "outs_ndim": [3], - "blocks": [32, 32, 1], - "threads_per_block": [32, 16, 1], - "run_fn": numba_set_all_values_to_255_batch, - "setup_fn": numba_setup_out_shape, - }, - ) + from nvidia.dali.plugin.numba.fn.experimental import ( + numba_function as experimental_numba_function, ) + from nvidia.dali.plugin.numba.fn import numba_function + + numba_function_args = { + "batch_processing": False, + "devices": numba_compatible_devices, + "in_types": [types.UINT8], + "ins_ndim": [3], + "out_types": [types.UINT8], + "outs_ndim": [3], + "blocks": [32, 32, 1], + "threads_per_block": [32, 16, 1], + "run_fn": numba_set_all_values_to_255_batch, + "setup_fn": numba_setup_out_shape, + } + + # Test both experimental and regular numba_function + ops_image_custom_args.append((experimental_numba_function, numba_function_args)) + ops_image_custom_args.append((numba_function, numba_function_args)) def test_ops_image_custom_args(): @@ -1427,11 +1432,12 @@ def pipeline(): pipe.run() -def test_video_decoder(): +@params(fn.experimental.decoders.video, fn.decoders.video) +def test_video_decoder(decoder): def video_decoder_pipe(max_batch_size, input_data, device): pipe = Pipeline(batch_size=max_batch_size, num_threads=4, device_id=0) encoded = fn.external_source(source=input_data, cycle=False, device="cpu") - decoded = fn.experimental.decoders.video(encoded, device=device) + decoded = decoder(encoded, device=device) pipe.set_outputs(decoded) return pipe @@ -1480,7 +1486,8 @@ def sample_gen(): check_pipeline(batches, inflate_pipline, devices=["gpu"]) -def test_debayer(): +@params(fn.experimental.debayer, fn.debayer) +def test_debayer(debayer_op): from debayer_test_utils import rgb2bayer, bayer_patterns, blue_position def debayer_pipline(max_batch_size, inputs, device): @@ -1494,7 +1501,7 @@ def piepline(): positions = fn.external_source(source=blue_positions) if device == "gpu": bayered = bayered.gpu() - return fn.experimental.debayer(bayered, blue_position=positions) + return debayer_op(bayered, blue_position=positions) return piepline(batch_size=max_batch_size, num_threads=4, device_id=0) @@ -1522,7 +1529,8 @@ def sample_gen(): check_pipeline(batches, debayer_pipline, devices=["gpu", "cpu"]) -def test_filter(): +@params(fn.experimental.filter, fn.filter) +def test_filter(filter_op): def filter_pipeline(max_batch_size, inputs, device): batches = [list(zip(*batch)) for batch in inputs] sample_batches = [list(inp_batch) for inp_batch, _, _ in batches] @@ -1534,7 +1542,7 @@ def pipeline(): samples = fn.external_source(source=sample_batches, layout="HWC") filters = fn.external_source(source=filter_batches) fill_values = fn.external_source(source=fill_value_bacthes) - return fn.experimental.filter(samples.gpu(), filters, fill_values, border="constant") + return filter_op(samples.gpu(), filters, fill_values, border="constant") return pipeline(batch_size=max_batch_size, num_threads=4, device_id=0) @@ -1803,20 +1811,24 @@ def get_data(batch_size): "decoders.image_random_crop", "decoders.image_slice", "decoders.numpy", + "decoders.video", "dl_tensor_python_function", "dump_image", - "experimental.equalize", + "equalize", "element_extract", "erase", "expand_dims", + "debayer", "experimental.debayer", "experimental.decoders.image", "experimental.decoders.image_crop", "experimental.decoders.image_slice", "experimental.decoders.image_random_crop", "experimental.decoders.video", + "experimental.decoders.hidden.video", "experimental.dilate", "experimental.erode", + "experimental.equalize", "experimental.filter", "decoders.inflate", "experimental.inflate", @@ -1824,9 +1836,11 @@ def get_data(batch_size): "experimental.peek_image_shape", "experimental.remap", "experimental.resize", + "experimental.tensor_resize", "experimental.warp_perspective", "external_source", "fast_resize_crop_mirror", + "filter", "flip", "gaussian_blur", "get_property", @@ -1879,6 +1893,7 @@ def get_data(batch_size): "normal_distribution", "normalize", "numba.fn.experimental.numba_function", + "numba.fn.numba_function", "one_hot", "optical_flow", "pad", @@ -1909,7 +1924,7 @@ def get_data(batch_size): "reshape", "resize", "resize_crop_mirror", - "experimental.tensor_resize", + "tensor_resize", "roi_random_crop", "rotate", "saturation", diff --git a/dali/test/python/test_utils.py b/dali/test/python/test_utils.py index 5819ccc20a4..2b455deda0b 100644 --- a/dali/test/python/test_utils.py +++ b/dali/test/python/test_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -955,12 +955,15 @@ def check_numba_compatibility_cpu(if_skip=True): return True -def check_numba_compatibility_gpu(if_skip=True): - import nvidia.dali.plugin.numba.experimental as ex +def check_numba_compatibility_gpu(if_skip=True, use_experimental: bool = False): + if use_experimental: + from nvidia.dali.plugin.numba.experimental import NumbaFunction + else: + from nvidia.dali.plugin.numba import NumbaFunction - if not ex.NumbaFunction._check_minimal_numba_version( + if not NumbaFunction._check_minimal_numba_version( False - ) or not ex.NumbaFunction._check_cuda_compatibility(False): + ) or not NumbaFunction._check_cuda_compatibility(False): if if_skip: raise SkipTest() else: diff --git a/dali/test/python/type_annotations/test_typing_pipelines.py b/dali/test/python/type_annotations/test_typing_pipelines.py index 859caee0635..7b386c73ae9 100644 --- a/dali/test/python/type_annotations/test_typing_pipelines.py +++ b/dali/test/python/type_annotations/test_typing_pipelines.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ from test_utils import get_dali_extra_path, check_numba_compatibility_cpu from nose_utils import attr # type: ignore +from nose2.tools import params # type: ignore _test_root = Path(get_dali_extra_path()) @@ -328,7 +329,8 @@ def torch_pipe(): @attr("numba") -def test_numba_plugin(): +@params(True, False) +def test_numba_plugin(use_experimental): import nvidia.dali.plugin.numba as dali_numba check_numba_compatibility_cpu() @@ -336,12 +338,18 @@ def test_numba_plugin(): def double_sample(out_sample, in_sample): out_sample[:] = 2 * in_sample[:] + numba_function = ( + dali_numba.fn.experimental.numba_function + if use_experimental + else dali_numba.fn.numba_function + ) + @pipeline_def(batch_size=2, device_id=0, num_threads=4) def numba_pipe(): forty_two = fn.external_source( source=lambda x: np.full((2,), 42, dtype=np.uint8), batch=False ) - out = dali_numba.fn.experimental.numba_function( + out = numba_function( forty_two, run_fn=double_sample, out_types=[types.DALIDataType.UINT8], @@ -350,7 +358,7 @@ def numba_pipe(): ins_ndim=[1], batch_processing=False, ) - out_from_const = dali_numba.fn.experimental.numba_function( + out_from_const = numba_function( [42], run_fn=double_sample, out_types=[types.DALIDataType.INT32], diff --git a/qa/TL0_python-self-test-core/test_body.sh b/qa/TL0_python-self-test-core/test_body.sh index d220bc5156c..dd71576860c 100644 --- a/qa/TL0_python-self-test-core/test_body.sh +++ b/qa/TL0_python-self-test-core/test_body.sh @@ -21,7 +21,6 @@ test_py_with_framework() { test_pipeline_segmentation.py \ test_triton_autoserialize.py \ test_functional_api.py \ - test_dali_variable_batch_size.py \ test_external_source_impl_utils.py); do if [ -z "$DALI_ENABLE_SANITIZERS" ]; then ${python_invoke_test} --attr "!slow,!pytorch,!mxnet,!cupy" ${test_script} @@ -30,6 +29,12 @@ test_py_with_framework() { fi done + if [ -z "$DALI_ENABLE_SANITIZERS" ]; then + ${python_new_invoke_test} -A "!slow,!pytorch,!mxnet,!cupy" test_dali_variable_batch_size + else + ${python_new_invoke_test} -A "!slow,!pytorch,!mxnet,!cupy,!numba" test_dali_variable_batch_size + fi + ${python_new_invoke_test} -A '!slow,!pytorch,!mxnet,!cupy' test_backend_impl if [ -z "$DALI_ENABLE_SANITIZERS" ]; then