2323#if _CCCL_HAS_BACKEND_CUDA()
2424
2525# include < cuda/__fwd/execution_policy.h>
26+ # include < cuda/__memory_resource/device_memory_pool.h>
27+ # include < cuda/__memory_resource/get_memory_resource.h>
28+ # include < cuda/__memory_resource/resource.h>
2629# include < cuda/__stream/get_stream.h>
2730# include < cuda/__stream/stream_ref.h>
2831# include < cuda/std/__execution/policy.h>
2932# include < cuda/std/__type_traits/is_execution_policy.h>
33+ # include < cuda/std/__utility/forward.h>
3034
3135# include < cuda/std/__cccl/prologue.h>
3236
@@ -37,7 +41,7 @@ struct __policy_stream_holder
3741{
3842 ::cuda::stream_ref __stream_;
3943
40- _CCCL_API constexpr __policy_stream_holder (::cuda::stream_ref __stream) noexcept
44+ _CCCL_HOST_API constexpr __policy_stream_holder (::cuda::stream_ref __stream) noexcept
4145 : __stream_(__stream)
4246 {}
4347};
@@ -51,27 +55,64 @@ struct __policy_stream_holder<false>
5155 _CCCL_HOST_API constexpr __policy_stream_holder (::cuda::stream_ref) noexcept {}
5256};
5357
58+ template <bool _HasResource>
59+ struct __policy_memory_resource_holder
60+ {
61+ using __resource_t = ::cuda::mr::any_resource<::cuda::mr::device_accessible>;
62+
63+ __resource_t __resource_;
64+
65+ _CCCL_TEMPLATE (class _Resource )
66+ _CCCL_REQUIRES (::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible>)
67+ _CCCL_HOST_API constexpr __policy_memory_resource_holder (_Resource&& __resource) noexcept
68+ : __resource_(::cuda::std::forward<_Resource>(__resource))
69+ {}
70+ };
71+
72+ template <>
73+ struct __policy_memory_resource_holder <false >
74+ {
75+ _CCCL_HIDE_FROM_ABI __policy_memory_resource_holder () = default;
76+
77+ // ! @brief Dummy constructor to simplify implementation of the cuda policy
78+ _CCCL_TEMPLATE (class _Resource )
79+ _CCCL_REQUIRES (::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible>)
80+ _CCCL_HOST_API constexpr __policy_memory_resource_holder (_Resource&&) noexcept {}
81+ };
82+
5483template <uint32_t _Policy>
5584struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_backend::__cuda>
5685 : __execution_policy_base<_Policy, __execution_backend::__none>
5786 , protected __policy_stream_holder<__cuda_policy_with_stream<_Policy>>
87+ , protected __policy_memory_resource_holder<__cuda_policy_with_memory_resource<_Policy>>
5888{
5989private:
6090 template <uint32_t , __execution_backend>
6191 friend struct __execution_policy_base ;
6292
63- using __stream_holder = __policy_stream_holder<__cuda_policy_with_stream<_Policy>>;
93+ using __stream_holder = __policy_stream_holder<__cuda_policy_with_stream<_Policy>>;
94+ using __resource_holder = __policy_memory_resource_holder<__cuda_policy_with_memory_resource<_Policy>>;
6495
6596 template <uint32_t _OtherPolicy>
6697 _CCCL_HOST_API constexpr __execution_policy_base (
6798 const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy) noexcept
6899 : __stream_holder(__policy.query(::cuda::get_stream))
100+ , __resource_holder(__policy.query(::cuda::mr::get_memory_resource))
69101 {}
70102
71103 template <uint32_t _OtherPolicy>
72104 _CCCL_HOST_API constexpr __execution_policy_base (
73- const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>&, ::cuda::stream_ref __stream) noexcept
105+ const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy,
106+ ::cuda::stream_ref __stream) noexcept
74107 : __stream_holder(__stream)
108+ , __resource_holder(__policy.query(::cuda::mr::get_memory_resource))
109+ {}
110+
111+ template <uint32_t _OtherPolicy, class _Resource >
112+ _CCCL_HOST_API constexpr __execution_policy_base (
113+ const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy, _Resource&& __resource) noexcept
114+ : __stream_holder(__policy.query(::cuda::get_stream))
115+ , __resource_holder(::cuda::std::forward<_Resource>(__resource))
75116 {}
76117
77118public:
@@ -109,6 +150,40 @@ struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_b
109150 }
110151 }
111152
153+ // ! @brief Set the current memory resource
154+ _CCCL_TEMPLATE (class _Resource , bool _WithResource = __cuda_policy_with_memory_resource<_Policy>)
155+ _CCCL_REQUIRES (::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible> _CCCL_AND _WithResource)
156+ [[nodiscard]] _CCCL_HOST_API __execution_policy_base& set_memory_resource (_Resource&& __resource) noexcept
157+ {
158+ this ->__resource_ = __resource;
159+ return *this ;
160+ }
161+
162+ // ! @brief Convert to a policy that holds a memory resource
163+ _CCCL_TEMPLATE (class _Resource , bool _WithResource = __cuda_policy_with_memory_resource<_Policy>)
164+ _CCCL_REQUIRES (::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible> _CCCL_AND (!_WithResource))
165+ [[nodiscard]] _CCCL_HOST_API auto set_memory_resource(_Resource&& __resource) const noexcept
166+ {
167+ constexpr uint32_t __new_policy =
168+ __set_cuda_backend_option<_Policy, __cuda_backend_options::__with_memory_resource>;
169+ return __execution_policy_base<__new_policy>{*this , __resource};
170+ }
171+
172+ // ! @brief Return either a stored or a default memory resource
173+ // ! @note We cannot put that into the __policy_memory_resource_holder because we need a stream for the device
174+ [[nodiscard]] _CCCL_HOST_API auto query (const ::cuda::mr::get_memory_resource_t &) const noexcept
175+ {
176+ if constexpr (__cuda_policy_with_memory_resource<_Policy>)
177+ {
178+ return this ->__resource_ ;
179+ }
180+ else
181+ {
182+ ::cuda::stream_ref __stream = this ->query (::cuda::get_stream);
183+ return ::cuda::device_default_memory_pool (__stream.device ());
184+ }
185+ }
186+
112187 template <uint32_t _OtherPolicy, __execution_backend _OtherBackend>
113188 [[nodiscard]] _CCCL_API friend constexpr bool operator ==(
114189 const __execution_policy_base& __lhs, const __execution_policy_base<_OtherPolicy, _OtherBackend>& __rhs) noexcept
@@ -126,6 +201,14 @@ struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_b
126201 }
127202 }
128203
204+ if constexpr (__cuda_policy_with_memory_resource<_Policy>)
205+ {
206+ if (__lhs.query (::cuda::mr::get_memory_resource) != __rhs.query (::cuda::mr::get_memory_resource))
207+ {
208+ return false ;
209+ }
210+ }
211+
129212 return true ;
130213 }
131214
0 commit comments