Commit d500c21

WIP feat: integrate sharding into pool
1 parent fb7c77b commit d500c21

4 files changed, +130 -8 lines changed

sqlx-core/src/pool/connect.rs

Lines changed: 5 additions & 0 deletions

@@ -220,6 +220,11 @@ pub trait PoolConnector<DB: Database>: Send + Sync + 'static {
     ) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_;
 }
 
+/// # Note: Future Changes (FIXME)
+/// This could theoretically be replaced with an impl over `AsyncFn` to allow lending closures,
+/// except we have no way to put the `Send` bound on the returned future.
+///
+/// We need Return Type Notation for that: https://github.com/rust-lang/rust/pull/138424
 impl<DB, F, Fut> PoolConnector<DB> for F
 where
     DB: Database,
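The FIXME note added above hinges on being able to name the connector's returned future so a `Send` bound can be placed on it. The standalone sketch below (all names and signatures are made up for illustration; only the pattern of naming the returned future as `Fut` mirrors the blanket impl in this file) shows why that bound matters: the future can be moved to another thread, e.g. a multi-threaded runtime, only because `Fut: Send` is spellable with a plain `Fn(..) -> Fut` bound, which `AsyncFn` cannot express until Return Type Notation lands.

```rust
use std::future::Future;

// Hedged, standalone illustration of the `Send`-bound problem described in the
// FIXME note; nothing here is sqlx API.
fn connect_on_another_thread<F, Fut>(connect: F) -> std::thread::JoinHandle<()>
where
    F: Fn(u32) -> Fut,
    Fut: Future<Output = String> + Send + 'static,
{
    // The future is created here and then moved across threads; this only
    // compiles because `Fut: Send` is expressible as a bound.
    let fut = connect(1);
    std::thread::spawn(move || {
        // A real pool would poll `fut` on an async runtime; dropping it keeps
        // the sketch dependency-free.
        drop(fut);
    })
}

fn main() {
    connect_on_another_thread(|id| async move { format!("connection #{id}") })
        .join()
        .unwrap();
}
```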

sqlx-core/src/pool/inner.rs

Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,7 @@ use std::task::ready;
 use crate::logger::private_level_filter_to_trace_level;
 use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector};
 use crate::pool::idle::IdleQueue;
+use crate::pool::shard::Sharded;
 use crate::rt::JoinHandle;
 use crate::{private_tracing_dynamic_event, rt};
 use either::Either;
@@ -24,6 +25,7 @@ use tracing::Level;
 pub(crate) struct PoolInner<DB: Database> {
     pub(super) connector: DynConnector<DB>,
     pub(super) counter: ConnectionCounter,
+    pub(super) sharded: Sharded<DB::Connection>,
     pub(super) idle: IdleQueue<DB>,
     is_closed: AtomicBool,
     pub(super) on_closed: event_listener::Event,
@@ -40,6 +42,7 @@ impl<DB: Database> PoolInner<DB> {
         let pool = Self {
             connector: DynConnector::new(connector),
             counter: ConnectionCounter::new(),
+            sharded: Sharded::new(options.max_connections, options.shards),
             idle: IdleQueue::new(options.fair, options.max_connections),
             is_closed: AtomicBool::new(false),
             on_closed: event_listener::Event::new(),

sqlx-core/src/pool/options.rs

Lines changed: 81 additions & 0 deletions

@@ -7,6 +7,7 @@ use crate::pool::{Pool, PoolConnector};
 use futures_core::future::BoxFuture;
 use log::LevelFilter;
 use std::fmt::{self, Debug, Formatter};
+use std::num::NonZero;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 
@@ -68,6 +69,7 @@ pub struct PoolOptions<DB: Database> {
         >,
     >,
     pub(crate) max_connections: usize,
+    pub(crate) shards: NonZero<usize>,
     pub(crate) acquire_time_level: LevelFilter,
     pub(crate) acquire_slow_level: LevelFilter,
     pub(crate) acquire_slow_threshold: Duration,
@@ -91,6 +93,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
             before_acquire: self.before_acquire.clone(),
             after_release: self.after_release.clone(),
             max_connections: self.max_connections,
+            shards: self.shards,
             acquire_time_level: self.acquire_time_level,
             acquire_slow_threshold: self.acquire_slow_threshold,
             acquire_slow_level: self.acquire_slow_level,
@@ -143,6 +146,7 @@ impl<DB: Database> PoolOptions<DB> {
             // A production application will want to set a higher limit than this.
             max_connections: 10,
             min_connections: 0,
+            shards: NonZero::<usize>::MIN,
             // Logging all acquires is opt-in
             acquire_time_level: LevelFilter::Off,
             // Default to warning, because an acquire timeout will be an error
@@ -206,6 +210,58 @@ impl<DB: Database> PoolOptions<DB> {
         self.min_connections
     }
 
+    /// Set the number of shards to split the internal structures into.
+    ///
+    /// The default value is dynamically determined based on the configured number of worker threads
+    /// in the current runtime (if that information is available),
+    /// or [`std::thread::available_parallelism()`],
+    /// or 1 otherwise.
+    ///
+    /// Each shard is assigned an equal share of [`max_connections`][Self::max_connections]
+    /// and its own queue of tasks waiting to acquire a connection.
+    ///
+    /// Then, when accessing the pool, each thread selects a "local" shard based on its
+    /// [thread ID][std::thread::Thread::id]<sup>1</sup>.
+    ///
+    /// If the number of shards equals the number of threads (which is the default),
+    /// and worker threads are spawned sequentially (which they generally are),
+    /// each thread should access a different shard, which should significantly reduce
+    /// cache coherence overhead on multicore systems.
+    ///
+    /// If the number of shards does not evenly divide `max_connections`,
+    /// the implementation makes a best effort to distribute them as evenly as possible
+    /// (if `remainder = max_connections % shards` and `remainder != 0`,
+    /// then `remainder` shards will get one additional connection each).
+    ///
+    /// The implementation then clamps the number of connections in a shard to the range `[1, 64]`.
+    ///
+    /// ### Details
+    /// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`),
+    /// it will first attempt to acquire a connection from its thread-local shard, or lock an empty
+    /// slot to open a new connection (acquiring an idle connection and opening a new connection
+    /// happen concurrently to minimize acquire time).
+    ///
+    /// Failing that, it joins the wait list on the shard. Released connections are passed to
+    /// waiting tasks in first-come, first-served order per shard.
+    ///
+    /// If the task cannot acquire a connection after a short delay,
+    /// it tries to acquire a connection from another shard.
+    ///
+    /// If the task _still_ cannot acquire a connection after a longer delay,
+    /// it joins a global wait list. Tasks in the global wait list have the highest priority
+    /// for released connections, implementing a kind of eventual fairness.
+    ///
+    /// <sup>1</sup> Because, as of this writing, [`std::thread::ThreadId::as_u64`] is unstable,
+    /// the current implementation assigns each thread its own sequential ID in a `thread_local!()`.
+    pub fn shards(mut self, shards: NonZero<usize>) -> Self {
+        self.shards = shards;
+        self
+    }
+
+    pub fn get_shards(&self) -> usize {
+        self.shards.get()
+    }
+
     /// Enable logging of time taken to acquire a connection from the connection pool via
     /// [`Pool::acquire()`].
     ///
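To make the `remainder` rule in the doc comment above concrete, here is a hedged, standalone sketch of the split; the helper name is made up, and the real logic (including edge-case handling) lives in `Params::calc()` in `sqlx-core/src/pool/shard.rs`:

```rust
// Illustrative only: split `max_connections` evenly across `shards`, give the
// first `remainder` shards one extra connection each, and clamp every size to
// the documented `[1, 64]` range. The actual implementation is
// `Params::calc()` / `shard_sizes()` in shard.rs and may differ in details.
fn shard_sizes(max_connections: usize, shards: usize) -> Vec<usize> {
    assert_ne!(shards, 0);

    let base = max_connections / shards;
    let remainder = max_connections % shards;

    (0..shards)
        .map(|i| (base + usize::from(i < remainder)).clamp(1, 64))
        .collect()
}

fn main() {
    // 10 connections over 4 shards: remainder = 2, so two shards get 3 and two get 2.
    assert_eq!(shard_sizes(10, 4), vec![3, 3, 2, 2]);
    println!("{:?}", shard_sizes(10, 4));
}
```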
@@ -572,3 +628,28 @@ impl<DB: Database> Debug for PoolOptions<DB> {
             .finish()
     }
 }
+
+fn default_shards() -> NonZero<usize> {
+    #[cfg(feature = "_rt-tokio")]
+    if let Ok(rt) = tokio::runtime::Handle::try_current() {
+        return rt
+            .metrics()
+            .num_workers()
+            .try_into()
+            .unwrap_or(NonZero::<usize>::MIN);
+    }
+
+    #[cfg(feature = "_rt-async-std")]
+    if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT")
+        .ok()
+        .and_then(|s| s.parse().ok())
+    {
+        return val;
+    }
+
+    if let Ok(val) = std::thread::available_parallelism() {
+        return val;
+    }
+
+    NonZero::<usize>::MIN
+}
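Assuming the builder method lands as written above, configuring a fixed shard count would look roughly like this (the connection URL is a placeholder, and `PgPoolOptions` is sqlx's Postgres alias for `PoolOptions`):

```rust
use std::num::NonZero;

use sqlx::postgres::PgPoolOptions;

// Hedged usage sketch against this WIP branch: pin the pool to 4 shards
// instead of letting it derive the count from the runtime's worker threads.
#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    let pool = PgPoolOptions::new()
        .shards(NonZero::new(4).expect("4 is non-zero"))
        .connect("postgres://localhost/mydb") // placeholder URL
        .await?;

    let _conn = pool.acquire().await?;
    Ok(())
}
```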

sqlx-core/src/pool/shard.rs

Lines changed: 41 additions & 8 deletions

@@ -1,6 +1,7 @@
 use event_listener::{Event, IntoNotification};
 use parking_lot::Mutex;
 use std::future::Future;
+use std::num::NonZero;
 use std::pin::pin;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
@@ -15,7 +16,11 @@ type ConnectionIndex = usize;
 ///
 /// We want tasks to acquire from their local shards where possible, so they don't enter
 /// the global queue immediately.
-const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(5);
+const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10);
+
+/// Delay before attempting to acquire from a non-local shard,
+/// as well as the backoff when iterating through shards.
+const NON_LOCAL_ACQUIRE_DELAY: Duration = Duration::from_micros(100);
 
 pub struct Sharded<T> {
     shards: Box<[ArcShard<T>]>,
@@ -32,8 +37,7 @@ struct Global<T> {
 type ArcMutexGuard<T> = parking_lot::ArcMutexGuard<parking_lot::RawMutex, Option<T>>;
 
 pub struct LockGuard<T> {
-    // `Option` allows us to drop the guard before sending the notification.
-    // Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex.
+    // `Option` allows us to take the guard in the drop handler.
     locked: Option<ArcMutexGuard<T>>,
     shard: ArcShard<T>,
     index: ConnectionIndex,
@@ -73,13 +77,13 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 {
 };
 
 impl<T> Sharded<T> {
-    pub fn new(connections: usize, shards: usize) -> Sharded<T> {
+    pub fn new(connections: usize, shards: NonZero<usize>) -> Sharded<T> {
         let global = Arc::new(Global {
             unlock_event: Event::with_tag(),
             disconnect_event: Event::with_tag(),
         });
 
-        let shards = Params::calc(connections, shards)
+        let shards = Params::calc(connections, shards.get())
             .shard_sizes()
             .enumerate()
             .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone()))
@@ -89,8 +93,28 @@ impl<T> Sharded<T> {
     }
 
     pub async fn acquire(&self, connected: bool) -> LockGuard<T> {
-        let mut acquire_local =
-            pin!(self.shards[thread_id() % self.shards.len()].acquire(connected));
+        if self.shards.len() == 1 {
+            return self.shards[0].acquire(connected).await;
+        }
+
+        let thread_id = current_thread_id();
+
+        let mut acquire_local = pin!(self.shards[thread_id % self.shards.len()].acquire(connected));
+
+        let mut acquire_nonlocal = pin!(async {
+            let mut next_shard = thread_id;
+
+            loop {
+                crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await;
+
+                // Choose shards pseudorandomly by multiplying with a (relatively) large prime.
+                next_shard = (next_shard.wrapping_mul(547)) % self.shards.len();
+
+                if let Some(locked) = self.shards[next_shard].try_acquire(connected) {
+                    return locked;
+                }
+            }
+        });
 
         let mut acquire_global = pin!(async {
             crate::rt::sleep(GLOBAL_QUEUE_DELAY).await;
@@ -113,6 +137,10 @@ impl<T> Sharded<T> {
                 return Poll::Ready(locked);
            }
 
+            if let Poll::Ready(locked) = acquire_nonlocal.as_mut().poll(cx) {
+                return Poll::Ready(locked);
+            }
+
             if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) {
                 return Poll::Ready(locked);
             }
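The `acquire()` hunks above poll the three acquisition strategies from one task in a fixed priority order: the thread-local shard is always preferred on any given wakeup, the delayed non-local scan comes second, and the global wait list is consulted last. A standalone, hedged sketch of that pattern using `std::future::poll_fn` (the real code may structure its poll loop differently):

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::task::Poll;

// Hedged sketch of priority-ordered polling: each strategy is pinned on the
// stack and polled in a fixed order, so whichever earlier strategy is ready
// on a wakeup wins even if a later one is also ready.
async fn race_in_priority_order<T>(
    local: impl Future<Output = T>,
    non_local: impl Future<Output = T>,
    global: impl Future<Output = T>,
) -> T {
    let mut local = pin!(local);
    let mut non_local = pin!(non_local);
    let mut global = pin!(global);

    poll_fn(move |cx| {
        if let Poll::Ready(out) = local.as_mut().poll(cx) {
            return Poll::Ready(out);
        }
        if let Poll::Ready(out) = non_local.as_mut().poll(cx) {
            return Poll::Ready(out);
        }
        if let Poll::Ready(out) = global.as_mut().poll(cx) {
            return Poll::Ready(out);
        }
        Poll::Pending
    })
    .await
}
```

Because the non-local and global strategies begin with `NON_LOCAL_ACQUIRE_DELAY` and `GLOBAL_QUEUE_DELAY` sleeps, the fixed ordering plus those delays is what gives the local shard its head start.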
@@ -125,6 +153,9 @@ impl<T> Sharded<T> {
 
 impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
     fn new(shard_id: ShardId, len: usize, global: Arc<Global<T>>) -> Arc<Self> {
+        // There's no way to create DSTs like this, in `std::sync::Arc`, on stable.
+        //
+        // Instead, we coerce from an array.
         macro_rules! make_array {
             ($($n:literal),+) => {
                 match len {
@@ -206,6 +237,8 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
 
 impl Params {
     fn calc(connections: usize, mut shards: usize) -> Params {
+        assert_ne!(shards, 0);
+
         let mut shard_size = connections / shards;
         let mut remainder = connections % shards;
 
@@ -239,7 +272,7 @@ impl Params {
     }
 }
 
-fn thread_id() -> usize {
+fn current_thread_id() -> usize {
     // FIXME: this can be replaced when this is stabilized:
     // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
     static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
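The renamed `current_thread_id()` is the workaround mentioned in the `options.rs` doc comment: while `ThreadId::as_u64()` is unstable, each thread lazily takes the next value from a global counter and caches it in a `thread_local!`. A hedged, standalone sketch of that scheme (illustrative names, not the exact sqlx code):

```rust
use std::cell::Cell;
use std::sync::atomic::{AtomicUsize, Ordering};

// Each thread asks the global counter once, then reuses its cached ID.
fn current_thread_id() -> usize {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

    thread_local! {
        static THIS_THREAD: Cell<Option<usize>> = Cell::new(None);
    }

    THIS_THREAD.with(|cell| {
        if let Some(id) = cell.get() {
            return id;
        }
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        cell.set(Some(id));
        id
    })
}

fn main() {
    // Four spawned threads receive the IDs 0..4 in some order.
    let mut ids: Vec<usize> = (0..4)
        .map(|_| std::thread::spawn(current_thread_id))
        .collect::<Vec<_>>()
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect();
    ids.sort();
    assert_eq!(ids, vec![0, 1, 2, 3]);
    println!("{ids:?}");
}
```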
