@@ -7,6 +7,7 @@ use crate::pool::{Pool, PoolConnector};
77use futures_core:: future:: BoxFuture ;
88use log:: LevelFilter ;
99use std:: fmt:: { self , Debug , Formatter } ;
10+ use std:: num:: NonZero ;
1011use std:: sync:: Arc ;
1112use std:: time:: { Duration , Instant } ;
1213
@@ -68,6 +69,7 @@ pub struct PoolOptions<DB: Database> {
6869 > ,
6970 > ,
7071 pub ( crate ) max_connections : usize ,
72+ pub ( crate ) shards : NonZero < usize > ,
7173 pub ( crate ) acquire_time_level : LevelFilter ,
7274 pub ( crate ) acquire_slow_level : LevelFilter ,
7375 pub ( crate ) acquire_slow_threshold : Duration ,
@@ -91,6 +93,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
9193 before_acquire : self . before_acquire . clone ( ) ,
9294 after_release : self . after_release . clone ( ) ,
9395 max_connections : self . max_connections ,
96+ shards : self . shards ,
9497 acquire_time_level : self . acquire_time_level ,
9598 acquire_slow_threshold : self . acquire_slow_threshold ,
9699 acquire_slow_level : self . acquire_slow_level ,
@@ -143,6 +146,7 @@ impl<DB: Database> PoolOptions<DB> {
143146 // A production application will want to set a higher limit than this.
144147 max_connections : 10 ,
145148 min_connections : 0 ,
149+ shards : NonZero :: < usize > :: MIN ,
146150 // Logging all acquires is opt-in
147151 acquire_time_level : LevelFilter :: Off ,
148152 // Default to warning, because an acquire timeout will be an error
@@ -206,6 +210,58 @@ impl<DB: Database> PoolOptions<DB> {
206210 self . min_connections
207211 }
208212
213+ /// Set the number of shards to split the internal structures into.
214+ ///
215+ /// The default value is dynamically determined based on the configured number of worker threads
216+ /// in the current runtime (if that information is available),
217+ /// or [`std::thread::available_parallelism()`],
218+ /// or 1 otherwise.
219+ ///
220+ /// Each shard is assigned an equal share of [`max_connections`][Self::max_connections]
221+ /// and its own queue of tasks waiting to acquire a connection.
222+ ///
223+ /// Then, when accessing the pool, each thread selects a "local" shard based on its
224+ /// [thread ID][std::thread::Thread::id]<sup>1</sup>.
225+ ///
226+ /// If the number of shards equals the number of threads (which they do by default),
227+ /// and worker threads are spawned sequentially (which they generally are),
228+ /// each thread should access a different shard, which should significantly reduce
229+ /// cache coherence overhead on multicore systems.
230+ ///
231+ /// If the number of shards does not evenly divide `max_connections`,
232+ /// the implementation makes a best-effort to distribute them as evenly as possible
233+ /// (if `remainder = max_connections % shards` and `remainder != 0`,
234+ /// then `remainder` shards will get one additional connection each).
235+ ///
236+ /// The implementation then clamps the number of connections in a shard to the range `[1, 64]`.
237+ ///
238+ /// ### Details
239+ /// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`),
240+ /// it will first attempt to acquire a connection from its thread-local shard, or lock an empty
241+ /// slot to open a new connection (acquiring an idle connection and opening a new connection
242+ /// happen concurrently to minimize acquire time).
243+ ///
244+ /// Failing that, it joins the wait list on the shard. Released connections are passed to
245+ /// waiting tasks in a first-come, first-serve order per shard.
246+ ///
247+ /// If the task cannot acquire a connection after a short delay,
248+ /// it tries to acquire a connection from another shard.
249+ ///
250+ /// If the task _still_ cannot acquire a connection after a longer delay,
251+ /// it joins a global wait list. Tasks in the global wait list are the highest priority
252+ /// for released connections, implementing a kind of eventual fairness.
253+ ///
254+ /// <sup>1</sup> because, as of writing, [`std::thread::ThreadId::as_u64`] is unstable,
255+ /// the current implementation assigns each thread its own sequential ID in a `thread_local!()`.
256+ pub fn shards ( mut self , shards : NonZero < usize > ) -> Self {
257+ self . shards = shards;
258+ self
259+ }
260+
261+ pub fn get_shards ( & self ) -> usize {
262+ self . shards . get ( )
263+ }
264+
209265 /// Enable logging of time taken to acquire a connection from the connection pool via
210266 /// [`Pool::acquire()`].
211267 ///
@@ -572,3 +628,28 @@ impl<DB: Database> Debug for PoolOptions<DB> {
572628 . finish ( )
573629 }
574630}
631+
/// Pick a default shard count for the pool.
///
/// Sources are consulted in order; the first one that yields a value wins:
///
/// 1. With the `_rt-tokio` feature enabled and a Tokio runtime handle available,
///    the worker count reported by the runtime's metrics (a count of zero maps to 1).
/// 2. With the `_rt-async-std` feature enabled, the `ASYNC_STD_THREAD_COUNT` environment
///    variable, if it is set and parses as a nonzero integer.
/// 3. [`std::thread::available_parallelism()`], if it succeeds.
/// 4. Otherwise, 1.
fn default_shards() -> NonZero<usize> {
    #[cfg(feature = "_rt-tokio")]
    if let Ok(rt) = tokio::runtime::Handle::try_current() {
        // A zero worker count cannot be represented as `NonZero`; use a single shard instead.
        return NonZero::new(rt.metrics().num_workers()).unwrap_or(NonZero::<usize>::MIN);
    }

    #[cfg(feature = "_rt-async-std")]
    if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT")
        .ok()
        // `str::parse()` returns a `Result`, but `Option::and_then()` requires an `Option`,
        // so convert with `.ok()`. An unparseable or zero value is ignored and we fall through.
        .and_then(|s| s.parse::<NonZero<usize>>().ok())
    {
        return val;
    }

    // Last resort: ask the OS, defaulting to a single shard if that fails.
    std::thread::available_parallelism().unwrap_or(NonZero::<usize>::MIN)
}
0 commit comments