Skip to content

Commit 6ab7088

Browse files
committed
fix: Hyper 1.0 migration - net.rs and trait fixes
- Added HyperStream wrapper for bridging tokio and hyper traits - Implemented hyper::rt::Read/Write for AddrStream - Removed invalid TlsStream trait impls (orphan rules) - Fixed Connector trait to require hyper 1.0 Read/Write - Down to 40 compilation errors from 84+ Remaining issues: - H2cMaker service trait - Axum Router trait bounds - StreamBody type mismatches
1 parent 60d694c commit 6ab7088

4 files changed

Lines changed: 178 additions & 21 deletions

File tree

libsql-server/src/connection/config.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ impl From<&metadata::DatabaseConfig> for DatabaseConfig {
8686
.map(NamespaceName::new_unchecked),
8787
durability_mode: match value.durability_mode {
8888
None => DurabilityMode::default(),
89-
Some(m) => DurabilityMode::from(metadata::DurabilityMode::try_from(m)),
89+
Some(m) => metadata::DurabilityMode::try_from(m)
90+
.map(|mode| DurabilityMode::from(mode))
91+
.unwrap_or_default(),
9092
},
9193
}
9294
}
@@ -171,16 +173,13 @@ impl From<DurabilityMode> for metadata::DurabilityMode {
171173
}
172174
}
173175

174-
impl From<Result<metadata::DurabilityMode, prost::DecodeError>> for DurabilityMode {
175-
fn from(value: Result<metadata::DurabilityMode, prost::DecodeError>) -> Self {
176-
match value {
177-
Ok(mode) => match mode {
178-
metadata::DurabilityMode::Relaxed => DurabilityMode::Relaxed,
179-
metadata::DurabilityMode::Strong => DurabilityMode::Strong,
180-
metadata::DurabilityMode::Extra => DurabilityMode::Extra,
181-
metadata::DurabilityMode::Off => DurabilityMode::Off,
182-
},
183-
Err(_) => DurabilityMode::default(),
176+
impl From<metadata::DurabilityMode> for DurabilityMode {
177+
fn from(mode: metadata::DurabilityMode) -> Self {
178+
match mode {
179+
metadata::DurabilityMode::Relaxed => DurabilityMode::Relaxed,
180+
metadata::DurabilityMode::Strong => DurabilityMode::Strong,
181+
metadata::DurabilityMode::Extra => DurabilityMode::Extra,
182+
metadata::DurabilityMode::Off => DurabilityMode::Off,
184183
}
185184
}
186185
}

libsql-server/src/h2c.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,25 @@ where
171171

172172
let executor = TokioExecutor::new();
173173
let conn = Http2Builder::new(executor);
174-
let svc = tower::service_fn(move |mut r: Request<Body>| {
174+
175+
// Create a service that handles incoming HTTP/2 requests
176+
let svc = hyper::service::service_fn(move |mut r: Request<hyper::body::Incoming>| {
175177
r.extensions_mut().insert(connect_info.clone());
176-
svc.call(r)
178+
// Convert the axum service response
179+
let svc_clone = svc.clone();
180+
async move {
181+
// Convert Request<Incoming> to Request<Body> for axum
182+
let (parts, body) = r.into_parts();
183+
let body = Body::from_stream(body);
184+
let req = Request::from_parts(parts, body);
185+
186+
svc_clone.call(req).await.map(|res| {
187+
// Convert Response<B> to Response<BoxBody>
188+
let (parts, body) = res.into_parts();
189+
let body = body.boxed_unsync();
190+
Response::from_parts(parts, body)
191+
}).map_err(|e| Box::new(e) as BoxError)
192+
}
177193
});
178194

179195
if let Err(e) = conn.serve_connection(upgraded_io, svc).await {

libsql-server/src/http/user/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ async fn handle_upgrade(
180180
return StatusCode::NOT_FOUND.into_response();
181181
}
182182

183+
// Convert axum Request<Body> to hyper Request<Incoming>
184+
// In axum 0.7, Body can be converted to Incoming by consuming it
185+
let (parts, body) = req.into_parts();
186+
let body = body.into_data_stream();
187+
let body = hyper::body::Body::from_stream(body);
188+
let req = Request::from_parts(parts, body);
189+
183190
let (response_tx, response_rx) = oneshot::channel();
184191
let _: Result<_, _> = upgrade_tx
185192
.send(hrana::ws::Upgrade {

libsql-server/src/net.rs

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,119 @@ use std::pin::Pin;
55
use std::task::{ready, Context, Poll};
66

77
use http::Uri;
8+
use hyper::rt::{Read, Write};
89
use hyper_util::client::legacy::connect::Connection;
10+
use hyper_util::rt::TokioIo;
911
use pin_project_lite::pin_project;
1012
use tokio::io::{AsyncRead, AsyncWrite};
1113
use tokio_rustls::server::TlsStream;
1214
use tonic::transport::server::{Connected, TcpConnectInfo};
1315
use tower::Service;
1416

17+
pin_project! {
18+
/// A wrapper that adds hyper 1.0's Read/Write traits to any tokio AsyncRead/AsyncWrite type.
19+
/// This uses TokioIo internally to bridge between tokio and hyper traits.
20+
pub struct HyperStream<S> {
21+
#[pin]
22+
inner: TokioIo<S>,
23+
}
24+
}
25+
26+
impl<S> HyperStream<S> {
27+
pub fn new(stream: S) -> Self {
28+
Self {
29+
inner: TokioIo::new(stream),
30+
}
31+
}
32+
33+
pub fn into_inner(self) -> S {
34+
self.inner.into_inner()
35+
}
36+
}
37+
38+
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for HyperStream<S> {
39+
fn poll_read(
40+
self: Pin<&mut Self>,
41+
cx: &mut Context<'_>,
42+
buf: &mut tokio::io::ReadBuf<'_>,
43+
) -> Poll<std::io::Result<()>> {
44+
// SAFETY: HyperStream is Unpin if S is Unpin
45+
let this = unsafe { self.get_unchecked_mut() };
46+
Pin::new(&mut this.inner).poll_read(cx, buf)
47+
}
48+
}
49+
50+
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for HyperStream<S> {
51+
fn poll_write(
52+
self: Pin<&mut Self>,
53+
cx: &mut Context<'_>,
54+
buf: &[u8],
55+
) -> Poll<std::io::Result<usize>> {
56+
let this = unsafe { self.get_unchecked_mut() };
57+
Pin::new(&mut this.inner).poll_write(cx, buf)
58+
}
59+
60+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
61+
let this = unsafe { self.get_unchecked_mut() };
62+
Pin::new(&mut this.inner).poll_flush(cx)
63+
}
64+
65+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
66+
let this = unsafe { self.get_unchecked_mut() };
67+
Pin::new(&mut this.inner).poll_shutdown(cx)
68+
}
69+
}
70+
71+
impl<S: AsyncRead + AsyncWrite + Unpin> Read for HyperStream<S> {
72+
fn poll_read(
73+
self: Pin<&mut Self>,
74+
cx: &mut Context<'_>,
75+
buf: hyper::rt::ReadBufCursor<'_>,
76+
) -> Poll<std::io::Result<()>> {
77+
self.project().inner.poll_read(cx, buf)
78+
}
79+
}
80+
81+
impl<S: AsyncRead + AsyncWrite + Unpin> Write for HyperStream<S> {
82+
fn poll_write(
83+
self: Pin<&mut Self>,
84+
cx: &mut Context<'_>,
85+
buf: &[u8],
86+
) -> Poll<std::io::Result<usize>> {
87+
self.project().inner.poll_write(cx, buf)
88+
}
89+
90+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
91+
self.project().inner.poll_flush(cx)
92+
}
93+
94+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
95+
self.project().inner.poll_shutdown(cx)
96+
}
97+
}
98+
99+
impl<S: AsyncRead + AsyncWrite + Connection + Unpin> Connection for HyperStream<S> {
100+
fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
101+
self.inner.inner().connected()
102+
}
103+
}
104+
15105
pub trait Connector:
16106
Service<Uri, Response = Self::Conn, Future = Self::Fut, Error = Self::Err>
17107
+ Send
18108
+ Sync
19109
+ 'static
20110
+ Clone
21111
{
22-
type Conn: Unpin + Send + 'static + AsyncRead + AsyncWrite + Connection;
112+
type Conn: Unpin + Send + 'static + AsyncRead + AsyncWrite + Read + Write + Connection;
23113
type Fut: Send + 'static + Unpin;
24114
type Err: Into<Box<dyn StdError + Send + Sync>> + Send + Sync;
25115
}
26116

27117
impl<T> Connector for T
28118
where
29119
T: Service<Uri> + Send + Sync + 'static + Clone,
30-
T::Response: Unpin + Send + 'static + AsyncRead + AsyncWrite + Connection,
120+
T::Response: Unpin + Send + 'static + AsyncRead + AsyncWrite + Read + Write + Connection,
31121
T::Future: Send + 'static + Unpin,
32122
T::Error: Into<Box<dyn StdError + Send + Sync>> + Send + Sync,
33123
{
@@ -36,7 +126,7 @@ where
36126
type Err = Self::Error;
37127
}
38128

39-
pub trait Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static {
129+
pub trait Conn: AsyncRead + AsyncWrite + Read + Write + Unpin + Send + 'static {
40130
fn connect_info(&self) -> TcpConnectInfo;
41131
}
42132

@@ -107,11 +197,8 @@ where
107197
}
108198
}
109199

110-
impl<C: Conn> Conn for TlsStream<C> {
111-
fn connect_info(&self) -> TcpConnectInfo {
112-
self.get_ref().0.connect_info()
113-
}
114-
}
200+
// Note: TlsStream doesn't implement Conn directly because it doesn't implement hyper::rt::Read/Write.
201+
// Use HyperStream<TlsStream<C>> when you need a connection that implements Conn.
115202

116203
impl<S> AsyncRead for AddrStream<S>
117204
where
@@ -153,6 +240,54 @@ where
153240
}
154241
}
155242

243+
impl<S> Read for AddrStream<S>
244+
where
245+
S: AsyncRead + AsyncWrite + Unpin,
246+
{
247+
fn poll_read(
248+
self: Pin<&mut Self>,
249+
cx: &mut Context<'_>,
250+
mut buf: hyper::rt::ReadBufCursor<'_>,
251+
) -> Poll<std::io::Result<()>> {
252+
// SAFETY: We're creating a tokio ReadBuf from the hyper ReadBufCursor
253+
let slice = unsafe {
254+
std::slice::from_raw_parts_mut(buf.as_mut().as_mut_ptr(), buf.as_mut().len())
255+
};
256+
let mut read_buf = tokio::io::ReadBuf::new(slice);
257+
258+
match self.project().stream.poll_read(cx, &mut read_buf) {
259+
Poll::Ready(Ok(())) => {
260+
let filled = read_buf.filled().len();
261+
unsafe { buf.advance(filled) };
262+
Poll::Ready(Ok(()))
263+
}
264+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
265+
Poll::Pending => Poll::Pending,
266+
}
267+
}
268+
}
269+
270+
impl<S> Write for AddrStream<S>
271+
where
272+
S: AsyncRead + AsyncWrite + Unpin,
273+
{
274+
fn poll_write(
275+
self: Pin<&mut Self>,
276+
cx: &mut Context<'_>,
277+
buf: &[u8],
278+
) -> Poll<std::io::Result<usize>> {
279+
self.project().stream.poll_write(cx, buf)
280+
}
281+
282+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
283+
self.project().stream.poll_flush(cx)
284+
}
285+
286+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
287+
self.project().stream.poll_shutdown(cx)
288+
}
289+
}
290+
156291
impl<S> Connected for AddrStream<S> {
157292
type ConnectInfo = TcpConnectInfo;
158293

0 commit comments

Comments
 (0)