Skip to content

Commit e8ddab3

Browse files
committed
WIP: Hyper 1.0 migration - RPC server fixes
- Simplified rpc/mod.rs to use tonic's serve_with_incoming - Added tonic::transport::server::Connected bound to Accept trait - Added From<hyper_util::client::legacy::Error> for LoadDumpError - Down to 39 errors from 43
1 parent 0e8667f commit e8ddab3

3 files changed

Lines changed: 141 additions & 129 deletions

File tree

libsql-server/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,12 @@ pub enum LoadDumpError {
298298
InvalidSqlInput(String),
299299
}
300300

301+
impl From<hyper_util::client::legacy::Error> for LoadDumpError {
302+
fn from(e: hyper_util::client::legacy::Error) -> Self {
303+
LoadDumpError::Internal(format!("HTTP client error: {}", e))
304+
}
305+
}
306+
301307
impl ResponseError for LoadDumpError {}
302308

303309
impl IntoResponse for &LoadDumpError {

libsql-server/src/net.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ pub trait Conn: AsyncRead + AsyncWrite + Read + Write + Unpin + Send + 'static {
102102
/// Trait for accepting incoming connections.
103103
/// This is the hyper 1.0+ compatible version that replaces `hyper::server::accept::Accept`.
104104
pub trait Accept: Unpin + Send + 'static {
105-
type Connection: Conn;
105+
type Connection: Conn + Connected<ConnectInfo = TcpConnectInfo>;
106106
type Error: std::error::Error + Send + Sync + 'static;
107107

108108
fn poll_accept(

libsql-server/src/rpc/mod.rs

Lines changed: 134 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,25 @@ use std::future::poll_fn;
22
use std::pin::Pin;
33
use std::sync::Arc;
44

5-
use hyper_util::rt::{TokioExecutor, TokioIo};
6-
use hyper_util::server::conn::auto::Builder as ConnBuilder;
5+
use async_stream::try_stream;
6+
use futures::Stream;
77
use libsql_replication::rpc::replication::replication_log_server::ReplicationLogServer;
88
use libsql_replication::rpc::replication::{BoxReplicationService, NAMESPACE_METADATA_KEY};
99
use rustls::pki_types::CertificateDer;
1010
use rustls::RootCertStore;
11+
use tokio::io::{AsyncRead, AsyncWrite};
1112
use tokio_rustls::TlsAcceptor;
13+
use tonic::transport::server::Connected;
1214
use tonic::Status;
1315
use tower::util::option_layer;
14-
use tower::Service;
1516
use tower::ServiceBuilder;
1617
use tower_http::trace::DefaultOnResponse;
1718
use tracing::Span;
1819

1920
use crate::config::TlsConfig;
2021
use crate::metrics::CLIENT_VERSION;
2122
use crate::namespace::NamespaceName;
22-
use crate::net::{Accept, Conn};
23+
use crate::net::Accept;
2324
use crate::rpc::proxy::rpc::proxy_server::ProxyServer;
2425
use crate::rpc::proxy::ProxyService;
2526
use crate::utils::services::idle_shutdown::IdleShutdownKicker;
@@ -31,18 +32,14 @@ pub mod streaming_exec;
3132

3233
pub async fn run_rpc_server<A: Accept>(
3334
proxy_service: ProxyService,
34-
mut acceptor: A,
35+
acceptor: A,
3536
maybe_tls: Option<TlsConfig>,
3637
idle_shutdown_layer: Option<IdleShutdownKicker>,
3738
service: BoxReplicationService,
3839
) -> anyhow::Result<()> {
39-
let router = tonic::transport::Server::builder()
40+
// Build the tonic server with services
41+
let mut server = tonic::transport::Server::builder()
4042
.layer(&option_layer(idle_shutdown_layer))
41-
.add_service(ProxyServer::new(proxy_service))
42-
.add_service(ReplicationLogServer::new(service))
43-
.into_router();
44-
45-
let svc = ServiceBuilder::new()
4643
.layer(
4744
tower_http::trace::TraceLayer::new_for_grpc()
4845
.on_request(trace_request)
@@ -51,63 +48,26 @@ pub async fn run_rpc_server<A: Accept>(
5148
.level(tracing::Level::DEBUG)
5249
.latency_unit(tower_http::LatencyUnit::Micros),
5350
),
54-
)
55-
.service(router);
51+
);
52+
53+
let router = server
54+
.add_service(ProxyServer::new(proxy_service))
55+
.add_service(ReplicationLogServer::new(service));
5656

5757
if let Some(tls_config) = maybe_tls {
58-
run_tls_server(&mut acceptor, svc, tls_config).await
58+
run_tls_server(acceptor, router, tls_config).await
5959
} else {
60-
run_plain_server(&mut acceptor, svc).await
60+
run_plain_server(acceptor, router).await
6161
}
6262
}
6363

64-
/// Wrapper service that converts hyper 1.0's Incoming body to tonic's BoxBody
65-
#[derive(Clone)]
66-
struct TonicServiceWrapper<S> {
67-
inner: S,
68-
}
69-
70-
impl<S, B> Service<hyper::Request<hyper::body::Incoming>> for TonicServiceWrapper<S>
71-
where
72-
S: Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible> + Clone + Send + 'static,
73-
S::Future: Send + 'static,
74-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
75-
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
76-
{
77-
type Response = hyper::Response<B>;
78-
type Error = std::convert::Infallible;
79-
type Future = S::Future;
80-
81-
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
82-
self.inner.poll_ready(cx)
83-
}
84-
85-
fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
86-
// Convert Incoming body to tonic's BoxBody
87-
// Need to map the error type from io::Error to tonic::Status
88-
let (parts, body) = req.into_parts();
89-
let body = body.map_err(|e| tonic::Status::internal(format!("body error: {}", e)));
90-
let body = tonic::body::BoxBody::new(body);
91-
let req = hyper::Request::from_parts(parts, body);
92-
self.inner.call(req)
93-
}
94-
}
95-
96-
async fn run_tls_server<A, S, B>(
97-
acceptor: &mut A,
98-
svc: S,
64+
async fn run_tls_server<A>(
65+
acceptor: A,
66+
router: tonic::transport::server::Router,
9967
tls_config: TlsConfig,
10068
) -> anyhow::Result<()>
10169
where
10270
A: Accept,
103-
S: tower::Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible>
104-
+ Clone
105-
+ Send
106-
+ 'static,
107-
S::Future: Send + 'static,
108-
S::Response: Send + 'static,
109-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
110-
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
11171
{
11272
let cert_pem = tokio::fs::read_to_string(&tls_config.cert).await?;
11373
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_pem.as_bytes())
@@ -116,7 +76,7 @@ where
11676
let key_pem = tokio::fs::read_to_string(&tls_config.key).await?;
11777
let keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut key_pem.as_bytes())
11878
.collect::<Result<Vec<_>, _>>()?;
119-
let key = rustls::pki_types::PrivateKeyDer::try_from(keys.into_iter().next().ok_or_else(|| anyhow::anyhow!("no private keys found"))?)?
79+
let key = rustls::pki_types::PrivateKeyDer::try_from(keys.into_iter().next().ok_or_else(|| anyhow::anyhow!("no private keys found"))?)?;
12080

12181
let ca_cert_pem = std::fs::read_to_string(&tls_config.ca_cert)?;
12282
let ca_certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_cert_pem.as_bytes())
@@ -137,90 +97,136 @@ where
13797
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
13898

13999
tracing::info!("serving internal rpc server with tls");
140-
141-
let wrapped_svc = TonicServiceWrapper { inner: svc };
142-
143-
// Drive the acceptor stream manually for hyper 1.0+ compatibility
144-
loop {
145-
let conn = match poll_fn(|cx| Pin::new(&mut *acceptor).poll_accept(cx)).await {
146-
Some(Ok(conn)) => conn,
147-
Some(Err(e)) => {
148-
tracing::error!("Accept error: {}", e);
149-
continue;
150-
}
151-
None => break,
152-
};
153-
154-
let tls_acceptor = tls_acceptor.clone();
155-
let svc = wrapped_svc.clone();
156-
157-
tokio::spawn(async move {
100+
101+
// Create a stream of TLS connections from the acceptor
102+
let incoming = tls_incoming_stream(acceptor, tls_acceptor);
103+
104+
// Serve with tonic's native server
105+
router.serve_with_incoming(incoming).await?;
106+
107+
Ok(())
108+
}
109+
110+
async fn run_plain_server<A>(
111+
acceptor: A,
112+
router: tonic::transport::server::Router,
113+
) -> anyhow::Result<()>
114+
where
115+
A: Accept,
116+
{
117+
tracing::info!("serving internal rpc server without tls");
118+
119+
// Create a stream of connections from the acceptor
120+
let incoming = plain_incoming_stream(acceptor);
121+
122+
// Serve with tonic's native server
123+
router.serve_with_incoming(incoming).await?;
124+
125+
Ok(())
126+
}
127+
128+
fn tls_incoming_stream<A>(
129+
mut acceptor: A,
130+
tls_acceptor: TlsAcceptor,
131+
) -> impl Stream<Item = Result<TlsStream<A::Connection>, anyhow::Error>>
132+
where
133+
A: Accept,
134+
{
135+
try_stream! {
136+
loop {
137+
let conn = match poll_fn(|cx| Pin::new(&mut acceptor).poll_accept(cx)).await {
138+
Some(Ok(conn)) => conn,
139+
Some(Err(e)) => {
140+
tracing::error!("Accept error: {}", e);
141+
continue;
142+
}
143+
None => break,
144+
};
145+
158146
let tls_stream = match tls_acceptor.accept(conn).await {
159147
Ok(tls_stream) => tls_stream,
160148
Err(err) => {
161149
tracing::error!("failed to perform tls handshake: {:#}", err);
162-
return;
150+
continue;
163151
}
164152
};
165153

166-
let io = TokioIo::new(tls_stream);
154+
yield TlsStream(tls_stream);
155+
}
156+
}
157+
}
167158

168-
if let Err(err) = ConnBuilder::new(TokioExecutor::new())
169-
.serve_connection(io, svc)
170-
.await
171-
{
172-
tracing::error!("failed to serve connection: {:#}", err);
173-
}
174-
});
159+
fn plain_incoming_stream<A>(
160+
mut acceptor: A,
161+
) -> impl Stream<Item = Result<A::Connection, anyhow::Error>>
162+
where
163+
A: Accept,
164+
{
165+
try_stream! {
166+
loop {
167+
let conn = match poll_fn(|cx| Pin::new(&mut acceptor).poll_accept(cx)).await {
168+
Some(Ok(conn)) => conn,
169+
Some(Err(e)) => {
170+
tracing::error!("Accept error: {}", e);
171+
continue;
172+
}
173+
None => break,
174+
};
175+
176+
yield conn;
177+
}
175178
}
179+
}
176180

177-
Ok(())
181+
// Wrapper for TLS stream to implement Connected
182+
pub struct TlsStream<S>(tokio_rustls::server::TlsStream<S>);
183+
184+
impl<S> AsyncRead for TlsStream<S>
185+
where
186+
S: AsyncRead + AsyncWrite + Unpin,
187+
{
188+
fn poll_read(
189+
self: Pin<&mut Self>,
190+
cx: &mut std::task::Context<'_>,
191+
buf: &mut tokio::io::ReadBuf<'_>,
192+
) -> std::task::Poll<std::io::Result<()>> {
193+
Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
194+
}
178195
}
179196

180-
async fn run_plain_server<A, S, B>(
181-
acceptor: &mut A,
182-
svc: S,
183-
) -> anyhow::Result<()>
197+
impl<S> AsyncWrite for TlsStream<S>
184198
where
185-
A: Accept,
186-
S: tower::Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible>
187-
+ Clone
188-
+ Send
189-
+ 'static,
190-
S::Future: Send + 'static,
191-
S::Response: Send + 'static,
192-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
193-
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
199+
S: AsyncRead + AsyncWrite + Unpin,
194200
{
195-
tracing::info!("serving internal rpc server without tls");
196-
let wrapped_svc = TonicServiceWrapper { inner: svc };
197-
198-
// Drive the acceptor stream manually for hyper 1.0+ compatibility
199-
loop {
200-
let conn = match poll_fn(|cx| Pin::new(&mut *acceptor).poll_accept(cx)).await {
201-
Some(Ok(conn)) => conn,
202-
Some(Err(e)) => {
203-
tracing::error!("Accept error: {}", e);
204-
continue;
205-
}
206-
None => break,
207-
};
208-
209-
let svc = wrapped_svc.clone();
210-
211-
tokio::spawn(async move {
212-
let io = TokioIo::new(conn);
213-
214-
if let Err(err) = ConnBuilder::new(TokioExecutor::new())
215-
.serve_connection(io, svc)
216-
.await
217-
{
218-
tracing::error!("failed to serve connection: {:#}", err);
219-
}
220-
});
201+
fn poll_write(
202+
self: Pin<&mut Self>,
203+
cx: &mut std::task::Context<'_>,
204+
buf: &[u8],
205+
) -> std::task::Poll<std::io::Result<usize>> {
206+
Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
221207
}
222208

223-
Ok(())
209+
fn poll_flush(
210+
self: Pin<&mut Self>,
211+
cx: &mut std::task::Context<'_>,
212+
) -> std::task::Poll<std::io::Result<()>> {
213+
Pin::new(&mut self.get_mut().0).poll_flush(cx)
214+
}
215+
216+
fn poll_shutdown(
217+
self: Pin<&mut Self>,
218+
cx: &mut std::task::Context<'_>,
219+
) -> std::task::Poll<std::io::Result<()>> {
220+
Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
221+
}
222+
}
223+
224+
impl<S: Connected> Connected for TlsStream<S> {
225+
type ConnectInfo = S::ConnectInfo;
226+
227+
fn connect_info(&self) -> Self::ConnectInfo {
228+
self.0.get_ref().0.connect_info()
229+
}
224230
}
225231

226232
fn extract_namespace<T>(
@@ -242,7 +248,7 @@ fn extract_namespace<T>(
242248
}
243249
}
244250

245-
fn trace_request<B>(req: &hyper::Request<B>, span: &Span) {
251+
fn trace_request<B>(req: &http::Request<B>, span: &Span) {
246252
let _s = span.enter();
247253

248254
tracing::debug!(

0 commit comments

Comments
 (0)