Skip to content

Commit 5947190

Browse files
committed
Implement a way to provide a memory_resource through the execution policy
1 parent d98cab3 commit 5947190

File tree

5 files changed

+222
-26
lines changed

5 files changed

+222
-26
lines changed

libcudacxx/include/cuda/__execution/policy.h

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323
#if _CCCL_HAS_BACKEND_CUDA()
2424

2525
# include <cuda/__fwd/execution_policy.h>
26+
# include <cuda/__memory_resource/device_memory_pool.h>
27+
# include <cuda/__memory_resource/get_memory_resource.h>
28+
# include <cuda/__memory_resource/resource.h>
2629
# include <cuda/__stream/get_stream.h>
2730
# include <cuda/__stream/stream_ref.h>
2831
# include <cuda/std/__execution/policy.h>
2932
# include <cuda/std/__type_traits/is_execution_policy.h>
33+
# include <cuda/std/__utility/forward.h>
3034

3135
# include <cuda/std/__cccl/prologue.h>
3236

@@ -37,7 +41,7 @@ struct __policy_stream_holder
3741
{
3842
::cuda::stream_ref __stream_;
3943

40-
_CCCL_API constexpr __policy_stream_holder(::cuda::stream_ref __stream) noexcept
44+
_CCCL_HOST_API constexpr __policy_stream_holder(::cuda::stream_ref __stream) noexcept
4145
: __stream_(__stream)
4246
{}
4347
};
@@ -51,27 +55,64 @@ struct __policy_stream_holder<false>
5155
_CCCL_HOST_API constexpr __policy_stream_holder(::cuda::stream_ref) noexcept {}
5256
};
5357

58+
template <bool _HasResource>
59+
struct __policy_memory_resource_holder
60+
{
61+
using __resource_t = ::cuda::mr::any_resource<::cuda::mr::device_accessible>;
62+
63+
__resource_t __resource_;
64+
65+
_CCCL_TEMPLATE(class _Resource)
66+
_CCCL_REQUIRES(::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible>)
67+
_CCCL_HOST_API constexpr __policy_memory_resource_holder(_Resource&& __resource) noexcept
68+
: __resource_(::cuda::std::forward<_Resource>(__resource))
69+
{}
70+
};
71+
72+
template <>
73+
struct __policy_memory_resource_holder<false>
74+
{
75+
_CCCL_HIDE_FROM_ABI __policy_memory_resource_holder() = default;
76+
77+
//! @brief Dummy constructor to simplify implementation of the cuda policy
78+
_CCCL_TEMPLATE(class _Resource)
79+
_CCCL_REQUIRES(::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible>)
80+
_CCCL_HOST_API constexpr __policy_memory_resource_holder(_Resource&&) noexcept {}
81+
};
82+
5483
template <uint32_t _Policy>
5584
struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_backend::__cuda>
5685
: __execution_policy_base<_Policy, __execution_backend::__none>
5786
, protected __policy_stream_holder<__cuda_policy_with_stream<_Policy>>
87+
, protected __policy_memory_resource_holder<__cuda_policy_with_memory_resource<_Policy>>
5888
{
5989
private:
6090
template <uint32_t, __execution_backend>
6191
friend struct __execution_policy_base;
6292

63-
using __stream_holder = __policy_stream_holder<__cuda_policy_with_stream<_Policy>>;
93+
using __stream_holder = __policy_stream_holder<__cuda_policy_with_stream<_Policy>>;
94+
using __resource_holder = __policy_memory_resource_holder<__cuda_policy_with_memory_resource<_Policy>>;
6495

6596
template <uint32_t _OtherPolicy>
6697
_CCCL_HOST_API constexpr __execution_policy_base(
6798
const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy) noexcept
6899
: __stream_holder(__policy.query(::cuda::get_stream))
100+
, __resource_holder(__policy.query(::cuda::mr::get_memory_resource))
69101
{}
70102

71103
template <uint32_t _OtherPolicy>
72104
_CCCL_HOST_API constexpr __execution_policy_base(
73-
const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>&, ::cuda::stream_ref __stream) noexcept
105+
const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy,
106+
::cuda::stream_ref __stream) noexcept
74107
: __stream_holder(__stream)
108+
, __resource_holder(__policy.query(::cuda::mr::get_memory_resource))
109+
{}
110+
111+
template <uint32_t _OtherPolicy, class _Resource>
112+
_CCCL_HOST_API constexpr __execution_policy_base(
113+
const __execution_policy_base<_OtherPolicy, __execution_backend::__cuda>& __policy, _Resource&& __resource) noexcept
114+
: __stream_holder(__policy.query(::cuda::get_stream))
115+
, __resource_holder(::cuda::std::forward<_Resource>(__resource))
75116
{}
76117

77118
public:
@@ -109,6 +150,40 @@ struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_b
109150
}
110151
}
111152

153+
//! @brief Set the current memory resource
154+
_CCCL_TEMPLATE(class _Resource, bool _WithResource = __cuda_policy_with_memory_resource<_Policy>)
155+
_CCCL_REQUIRES(::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible> _CCCL_AND _WithResource)
156+
[[nodiscard]] _CCCL_HOST_API __execution_policy_base& set_memory_resource(_Resource&& __resource) noexcept
157+
{
158+
this->__resource_ = __resource;
159+
return *this;
160+
}
161+
162+
//! @brief Convert to a policy that holds a memory resource
163+
_CCCL_TEMPLATE(class _Resource, bool _WithResource = __cuda_policy_with_memory_resource<_Policy>)
164+
_CCCL_REQUIRES(::cuda::mr::resource_with<_Resource, ::cuda::mr::device_accessible> _CCCL_AND(!_WithResource))
165+
[[nodiscard]] _CCCL_HOST_API auto set_memory_resource(_Resource&& __resource) const noexcept
166+
{
167+
constexpr uint32_t __new_policy =
168+
__set_cuda_backend_option<_Policy, __cuda_backend_options::__with_memory_resource>;
169+
return __execution_policy_base<__new_policy>{*this, __resource};
170+
}
171+
172+
//! @brief Return either a stored or a default memory resource
173+
//! @note We cannot put that into the __policy_memory_resource_holder because we need a stream for the device
174+
[[nodiscard]] _CCCL_HOST_API auto query(const ::cuda::mr::get_memory_resource_t&) const noexcept
175+
{
176+
if constexpr (__cuda_policy_with_memory_resource<_Policy>)
177+
{
178+
return this->__resource_;
179+
}
180+
else
181+
{
182+
::cuda::stream_ref __stream = this->query(::cuda::get_stream);
183+
return ::cuda::device_default_memory_pool(__stream.device());
184+
}
185+
}
186+
112187
template <uint32_t _OtherPolicy, __execution_backend _OtherBackend>
113188
[[nodiscard]] _CCCL_API friend constexpr bool operator==(
114189
const __execution_policy_base& __lhs, const __execution_policy_base<_OtherPolicy, _OtherBackend>& __rhs) noexcept
@@ -126,6 +201,14 @@ struct _CCCL_DECLSPEC_EMPTY_BASES __execution_policy_base<_Policy, __execution_b
126201
}
127202
}
128203

204+
if constexpr (__cuda_policy_with_memory_resource<_Policy>)
205+
{
206+
if (__lhs.query(::cuda::mr::get_memory_resource) != __rhs.query(::cuda::mr::get_memory_resource))
207+
{
208+
return false;
209+
}
210+
}
211+
129212
return true;
130213
}
131214

libcudacxx/include/cuda/__fwd/execution_policy.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ _CCCL_BEGIN_NAMESPACE_CUDA_STD_EXECUTION
3030

3131
enum __cuda_backend_options : uint16_t
3232
{
33-
__with_stream = 1 << 0, ///> Determines whether the policy holds a stream
33+
__with_stream = 1 << 0, ///> Determines whether the policy holds a stream
34+
__with_memory_resource = 1 << 1, ///> Determines whether the policy holds a memory resource
3435
};
3536

3637
//! @brief Sets the execution backend to cuda
@@ -58,6 +59,11 @@ template <uint32_t _Policy>
5859
inline constexpr bool __cuda_policy_with_stream =
5960
__policy_to_cuda_backend_options<_Policy> & __cuda_backend_options::__with_stream;
6061

62+
//! @brief Detects whether a given policy holds a user provided memory resource
63+
template <uint32_t _Policy>
64+
inline constexpr bool __cuda_policy_with_memory_resource =
65+
__policy_to_cuda_backend_options<_Policy> & __cuda_backend_options::__with_memory_resource;
66+
6167
_CCCL_END_NAMESPACE_CUDA_STD_EXECUTION
6268

6369
# include <cuda/std/__cccl/epilogue.h>

libcudacxx/include/cuda/std/__pstl/cuda/reduce.h

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wshadow")
3131
_CCCL_DIAG_POP
3232

3333
# include <cuda/__execution/policy.h>
34+
# include <cuda/__memory_resource/get_memory_resource.h>
3435
# include <cuda/__runtime/api_wrapper.h>
3536
# include <cuda/__stream/get_stream.h>
3637
# include <cuda/std/__exception/cuda_error.h>
@@ -58,7 +59,7 @@ template <>
5859
struct __pstl_dispatch<__pstl_algorithm::__reduce, __execution_backend::__cuda>
5960
{
6061
//! Ensures we properly deallocate the memory allocated for the result
61-
template <class _Tp>
62+
template <class _Tp, class _Resource>
6263
struct __allocation_guard
6364
{
6465
//! This helper struct ensures that we can properly assign types with a nontrivial assignment operator
@@ -78,24 +79,18 @@ struct __pstl_dispatch<__pstl_algorithm::__reduce, __execution_backend::__cuda>
7879
};
7980

8081
::cuda::stream_ref __stream_;
82+
_Resource& __resource_;
8183
_Tp* __ptr_;
8284

83-
_CCCL_HOST_API __allocation_guard(::cuda::stream_ref __stream)
85+
_CCCL_HOST_API __allocation_guard(::cuda::stream_ref __stream, _Resource& __resource)
8486
: __stream_(__stream)
85-
, __ptr_(nullptr)
86-
{
87-
_CCCL_TRY_CUDA_API(
88-
::cudaMallocAsync,
89-
"__pstl_cuda_reduce: allocation failed",
90-
reinterpret_cast<void**>(&__ptr_),
91-
sizeof(_Tp),
92-
__stream_.get());
93-
}
87+
, __resource_(__resource)
88+
, __ptr_(static_cast<_Tp*>(__resource_.allocate(__stream_, sizeof(_Tp), alignof(_Tp))))
89+
{}
9490

9591
_CCCL_HOST_API ~__allocation_guard()
9692
{
97-
_CCCL_TRY_CUDA_API(::cudaFreeAsync, "__pstl_cuda_reduce: deallocate failed", __ptr_, __stream_.get());
98-
93+
__resource_.deallocate(__stream_, __ptr_, sizeof(_Tp), alignof(_Tp));
9994
__stream_.sync();
10095
}
10196

@@ -113,8 +108,9 @@ struct __pstl_dispatch<__pstl_algorithm::__reduce, __execution_backend::__cuda>
113108

114109
{
115110
// Allocate memory for result
116-
auto __stream = __policy.query(::cuda::get_stream);
117-
__allocation_guard<_Tp> __guard{__stream};
111+
auto __stream = __policy.query(::cuda::get_stream);
112+
auto __resource = __policy.query(::cuda::mr::get_memory_resource);
113+
__allocation_guard<_Tp, decltype(__resource)> __guard{__stream, __resource};
118114

119115
const auto __count = ::cuda::std::distance(__first, __last);
120116
_CCCL_TRY_CUDA_API(
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// UNSUPPORTED: nvrtc
12+
13+
#include <cuda/memory_resource>
14+
#include <cuda/std/__pstl/for_each.h>
15+
#include <cuda/std/execution>
16+
#include <cuda/std/memory>
17+
#include <cuda/std/type_traits>
18+
#include <cuda/stream>
19+
20+
struct test_resource
21+
{
22+
__host__ __device__ void* allocate_sync(std::size_t, std::size_t)
23+
{
24+
return nullptr;
25+
}
26+
27+
__host__ __device__ void deallocate_sync(void* ptr, std::size_t, std::size_t) noexcept
28+
{
29+
// ensure that we did get the right inputs forwarded
30+
_val = *static_cast<int*>(ptr);
31+
}
32+
33+
__host__ __device__ void* allocate(cuda::stream_ref, std::size_t, std::size_t)
34+
{
35+
return &_val;
36+
}
37+
38+
__host__ __device__ void deallocate(cuda::stream_ref, void* ptr, std::size_t, std::size_t)
39+
{
40+
// ensure that we did get the right inputs forwarded
41+
_val = *static_cast<int*>(ptr);
42+
}
43+
44+
__host__ __device__ bool operator==(const test_resource& other) const
45+
{
46+
return _val == other._val;
47+
}
48+
__host__ __device__ bool operator!=(const test_resource& other) const
49+
{
50+
return _val != other._val;
51+
}
52+
53+
friend constexpr void get_property(const test_resource&, ::cuda::mr::device_accessible) noexcept {}
54+
55+
int _val = 0;
56+
};
57+
58+
template <class Policy>
59+
void test(Policy pol)
60+
{
61+
auto old_stream = ::cuda::get_stream(pol);
62+
{ // Ensure that the plain policy returns a well defined memory resource
63+
auto expected_resource = ::cuda::device_default_memory_pool(cuda::device_ref{0});
64+
assert(cuda::mr::get_memory_resource(pol) == expected_resource);
65+
}
66+
67+
{ // Ensure that we can attach a memory resource to an execution policy
68+
test_resource resource{42};
69+
auto pol_with_resource = pol.set_memory_resource(resource);
70+
assert(cuda::mr::get_memory_resource(pol_with_resource) == resource);
71+
assert(cuda::get_stream(pol_with_resource) == old_stream);
72+
73+
using policy_t = decltype(pol_with_resource);
74+
static_assert(noexcept(pol.set_memory_resource(resource)));
75+
static_assert(cuda::std::is_execution_policy_v<policy_t>);
76+
}
77+
78+
{ // Ensure that attaching a memory resource multiple times just overwrites the old one
79+
test_resource resource{42};
80+
auto pol_with_resource = pol.set_memory_resource(resource);
81+
assert(cuda::mr::get_memory_resource(pol_with_resource) == resource);
82+
assert(cuda::get_stream(pol_with_resource) == old_stream);
83+
84+
using policy_t = decltype(pol_with_resource);
85+
test_resource other_resource{1337};
86+
decltype(auto) pol_with_other_resource = pol_with_resource.set_memory_resource(other_resource);
87+
static_assert(cuda::std::is_same_v<decltype(pol_with_other_resource), policy_t&>);
88+
assert(::cuda::mr::get_memory_resource(pol_with_resource) == other_resource);
89+
assert(::cuda::mr::get_memory_resource(pol_with_other_resource) == other_resource);
90+
assert(cuda::std::addressof(pol_with_resource) == cuda::std::addressof(pol_with_other_resource));
91+
assert(cuda::get_stream(pol_with_resource) == old_stream);
92+
}
93+
}
94+
95+
void test()
96+
{
97+
namespace execution = cuda::std::execution;
98+
static_assert(!execution::__queryable_with<execution::sequenced_policy, ::cuda::mr::get_memory_resource_t>);
99+
static_assert(!execution::__queryable_with<execution::parallel_policy, ::cuda::mr::get_memory_resource_t>);
100+
static_assert(
101+
!execution::__queryable_with<execution::parallel_unsequenced_policy, ::cuda::mr::get_memory_resource_t>);
102+
static_assert(!execution::__queryable_with<execution::unsequenced_policy, ::cuda::mr::get_memory_resource_t>);
103+
104+
test(cuda::execution::__cub_par_unseq);
105+
106+
// Ensure that all works even if we have a stream attached
107+
test(cuda::execution::__cub_par_unseq.set_stream(::cuda::stream{cuda::device_ref{0}}));
108+
}
109+
110+
int main(int, char**)
111+
{
112+
NV_IF_TARGET(NV_IS_HOST, (test();))
113+
114+
return 0;
115+
}

libcudacxx/test/libcudacxx/cuda/execution/execution_policy/get_stream.pass.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@
2222
template <class Policy>
2323
void test(Policy pol)
2424
{
25-
namespace execution = cuda::std::execution;
26-
static_assert(!execution::__queryable_with<execution::sequenced_policy, ::cuda::get_stream_t>);
27-
static_assert(!execution::__queryable_with<execution::parallel_policy, ::cuda::get_stream_t>);
28-
static_assert(!execution::__queryable_with<execution::parallel_unsequenced_policy, ::cuda::get_stream_t>);
29-
static_assert(!execution::__queryable_with<execution::unsequenced_policy, ::cuda::get_stream_t>);
30-
3125
{ // Ensure that the plain policy returns a well defined stream
3226
cuda::stream_ref expected_stream{cudaStreamPerThread};
3327
assert(cuda::get_stream(pol) == expected_stream);
@@ -41,7 +35,6 @@ void test(Policy pol)
4135
using stream_policy_t = decltype(pol_with_stream);
4236
static_assert(noexcept(pol.set_stream(stream)));
4337
static_assert(cuda::std::is_execution_policy_v<stream_policy_t>);
44-
static_assert(cuda::std::is_base_of_v<cuda::std::execution::__policy_stream_holder<true>, stream_policy_t>);
4538
}
4639

4740
{ // Ensure that attaching a stream multiple times just overwrites the old stream
@@ -68,6 +61,9 @@ void test()
6861
static_assert(!execution::__queryable_with<execution::unsequenced_policy, ::cuda::get_stream_t>);
6962

7063
test(cuda::execution::__cub_par_unseq);
64+
65+
// Ensure that all works even if we have a memory resource
66+
test(cuda::execution::__cub_par_unseq.set_memory_resource(::cuda::device_default_memory_pool(::cuda::device_ref{0})));
7167
}
7268

7369
int main(int, char**)

0 commit comments

Comments
 (0)