@@ -2,24 +2,25 @@ use std::future::poll_fn;
22use std:: pin:: Pin ;
33use std:: sync:: Arc ;
44
5- use hyper_util :: rt :: { TokioExecutor , TokioIo } ;
6- use hyper_util :: server :: conn :: auto :: Builder as ConnBuilder ;
5+ use async_stream :: try_stream ;
6+ use futures :: Stream ;
77use libsql_replication:: rpc:: replication:: replication_log_server:: ReplicationLogServer ;
88use libsql_replication:: rpc:: replication:: { BoxReplicationService , NAMESPACE_METADATA_KEY } ;
99use rustls:: pki_types:: CertificateDer ;
1010use rustls:: RootCertStore ;
11+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
1112use tokio_rustls:: TlsAcceptor ;
13+ use tonic:: transport:: server:: Connected ;
1214use tonic:: Status ;
1315use tower:: util:: option_layer;
14- use tower:: Service ;
1516use tower:: ServiceBuilder ;
1617use tower_http:: trace:: DefaultOnResponse ;
1718use tracing:: Span ;
1819
1920use crate :: config:: TlsConfig ;
2021use crate :: metrics:: CLIENT_VERSION ;
2122use crate :: namespace:: NamespaceName ;
22- use crate :: net:: { Accept , Conn } ;
23+ use crate :: net:: Accept ;
2324use crate :: rpc:: proxy:: rpc:: proxy_server:: ProxyServer ;
2425use crate :: rpc:: proxy:: ProxyService ;
2526use crate :: utils:: services:: idle_shutdown:: IdleShutdownKicker ;
@@ -31,18 +32,14 @@ pub mod streaming_exec;
3132
3233pub async fn run_rpc_server < A : Accept > (
3334 proxy_service : ProxyService ,
34- mut acceptor : A ,
35+ acceptor : A ,
3536 maybe_tls : Option < TlsConfig > ,
3637 idle_shutdown_layer : Option < IdleShutdownKicker > ,
3738 service : BoxReplicationService ,
3839) -> anyhow:: Result < ( ) > {
39- let router = tonic:: transport:: Server :: builder ( )
40+ // Build the tonic server with services
41+ let mut server = tonic:: transport:: Server :: builder ( )
4042 . layer ( & option_layer ( idle_shutdown_layer) )
41- . add_service ( ProxyServer :: new ( proxy_service) )
42- . add_service ( ReplicationLogServer :: new ( service) )
43- . into_router ( ) ;
44-
45- let svc = ServiceBuilder :: new ( )
4643 . layer (
4744 tower_http:: trace:: TraceLayer :: new_for_grpc ( )
4845 . on_request ( trace_request)
@@ -51,63 +48,26 @@ pub async fn run_rpc_server<A: Accept>(
5148 . level ( tracing:: Level :: DEBUG )
5249 . latency_unit ( tower_http:: LatencyUnit :: Micros ) ,
5350 ) ,
54- )
55- . service ( router) ;
51+ ) ;
52+
53+ let router = server
54+ . add_service ( ProxyServer :: new ( proxy_service) )
55+ . add_service ( ReplicationLogServer :: new ( service) ) ;
5656
5757 if let Some ( tls_config) = maybe_tls {
58- run_tls_server ( & mut acceptor, svc , tls_config) . await
58+ run_tls_server ( acceptor, router , tls_config) . await
5959 } else {
60- run_plain_server ( & mut acceptor, svc ) . await
60+ run_plain_server ( acceptor, router ) . await
6161 }
6262}
6363
64- /// Wrapper service that converts hyper 1.0's Incoming body to tonic's BoxBody
65- #[ derive( Clone ) ]
66- struct TonicServiceWrapper < S > {
67- inner : S ,
68- }
69-
70- impl < S , B > Service < hyper:: Request < hyper:: body:: Incoming > > for TonicServiceWrapper < S >
71- where
72- S : Service < hyper:: Request < tonic:: body:: BoxBody > , Response = hyper:: Response < B > , Error = std:: convert:: Infallible > + Clone + Send + ' static ,
73- S :: Future : Send + ' static ,
74- B : http_body:: Body < Data = bytes:: Bytes > + Send + ' static ,
75- B :: Error : Into < Box < dyn std:: error:: Error + Send + Sync > > + Send + Sync + ' static ,
76- {
77- type Response = hyper:: Response < B > ;
78- type Error = std:: convert:: Infallible ;
79- type Future = S :: Future ;
80-
81- fn poll_ready ( & mut self , cx : & mut std:: task:: Context < ' _ > ) -> std:: task:: Poll < Result < ( ) , Self :: Error > > {
82- self . inner . poll_ready ( cx)
83- }
84-
85- fn call ( & mut self , req : hyper:: Request < hyper:: body:: Incoming > ) -> Self :: Future {
86- // Convert Incoming body to tonic's BoxBody
87- // Need to map the error type from io::Error to tonic::Status
88- let ( parts, body) = req. into_parts ( ) ;
89- let body = body. map_err ( |e| tonic:: Status :: internal ( format ! ( "body error: {}" , e) ) ) ;
90- let body = tonic:: body:: BoxBody :: new ( body) ;
91- let req = hyper:: Request :: from_parts ( parts, body) ;
92- self . inner . call ( req)
93- }
94- }
95-
96- async fn run_tls_server < A , S , B > (
97- acceptor : & mut A ,
98- svc : S ,
64+ async fn run_tls_server < A > (
65+ acceptor : A ,
66+ router : tonic:: transport:: server:: Router ,
9967 tls_config : TlsConfig ,
10068) -> anyhow:: Result < ( ) >
10169where
10270 A : Accept ,
103- S : tower:: Service < hyper:: Request < tonic:: body:: BoxBody > , Response = hyper:: Response < B > , Error = std:: convert:: Infallible >
104- + Clone
105- + Send
106- + ' static ,
107- S :: Future : Send + ' static ,
108- S :: Response : Send + ' static ,
109- B : http_body:: Body < Data = bytes:: Bytes > + Send + ' static ,
110- B :: Error : Into < Box < dyn std:: error:: Error + Send + Sync > > + Send + Sync + ' static ,
11171{
11272 let cert_pem = tokio:: fs:: read_to_string ( & tls_config. cert ) . await ?;
11373 let certs: Vec < CertificateDer < ' static > > = rustls_pemfile:: certs ( & mut cert_pem. as_bytes ( ) )
11676 let key_pem = tokio:: fs:: read_to_string ( & tls_config. key ) . await ?;
11777 let keys: Vec < _ > = rustls_pemfile:: pkcs8_private_keys ( & mut key_pem. as_bytes ( ) )
11878 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
119- let key = rustls:: pki_types:: PrivateKeyDer :: try_from ( keys. into_iter ( ) . next ( ) . ok_or_else ( || anyhow:: anyhow!( "no private keys found" ) ) ?) ?
79+ let key = rustls:: pki_types:: PrivateKeyDer :: try_from ( keys. into_iter ( ) . next ( ) . ok_or_else ( || anyhow:: anyhow!( "no private keys found" ) ) ?) ?;
12080
12181 let ca_cert_pem = std:: fs:: read_to_string ( & tls_config. ca_cert ) ?;
12282 let ca_certs: Vec < CertificateDer < ' static > > = rustls_pemfile:: certs ( & mut ca_cert_pem. as_bytes ( ) )
@@ -137,90 +97,136 @@ where
13797 let tls_acceptor = TlsAcceptor :: from ( Arc :: new ( config) ) ;
13898
13999 tracing:: info!( "serving internal rpc server with tls" ) ;
140-
141- let wrapped_svc = TonicServiceWrapper { inner : svc } ;
142-
143- // Drive the acceptor stream manually for hyper 1.0+ compatibility
144- loop {
145- let conn = match poll_fn ( |cx| Pin :: new ( & mut * acceptor) . poll_accept ( cx) ) . await {
146- Some ( Ok ( conn) ) => conn,
147- Some ( Err ( e) ) => {
148- tracing:: error!( "Accept error: {}" , e) ;
149- continue ;
150- }
151- None => break ,
152- } ;
153-
154- let tls_acceptor = tls_acceptor. clone ( ) ;
155- let svc = wrapped_svc. clone ( ) ;
156-
157- tokio:: spawn ( async move {
100+
101+ // Create a stream of TLS connections from the acceptor
102+ let incoming = tls_incoming_stream ( acceptor, tls_acceptor) ;
103+
104+ // Serve with tonic's native server
105+ router. serve_with_incoming ( incoming) . await ?;
106+
107+ Ok ( ( ) )
108+ }
109+
110+ async fn run_plain_server < A > (
111+ acceptor : A ,
112+ router : tonic:: transport:: server:: Router ,
113+ ) -> anyhow:: Result < ( ) >
114+ where
115+ A : Accept ,
116+ {
117+ tracing:: info!( "serving internal rpc server without tls" ) ;
118+
119+ // Create a stream of connections from the acceptor
120+ let incoming = plain_incoming_stream ( acceptor) ;
121+
122+ // Serve with tonic's native server
123+ router. serve_with_incoming ( incoming) . await ?;
124+
125+ Ok ( ( ) )
126+ }
127+
128+ fn tls_incoming_stream < A > (
129+ mut acceptor : A ,
130+ tls_acceptor : TlsAcceptor ,
131+ ) -> impl Stream < Item = Result < TlsStream < A :: Connection > , anyhow:: Error > >
132+ where
133+ A : Accept ,
134+ {
135+ try_stream ! {
136+ loop {
137+ let conn = match poll_fn( |cx| Pin :: new( & mut acceptor) . poll_accept( cx) ) . await {
138+ Some ( Ok ( conn) ) => conn,
139+ Some ( Err ( e) ) => {
140+ tracing:: error!( "Accept error: {}" , e) ;
141+ continue ;
142+ }
143+ None => break ,
144+ } ;
145+
158146 let tls_stream = match tls_acceptor. accept( conn) . await {
159147 Ok ( tls_stream) => tls_stream,
160148 Err ( err) => {
161149 tracing:: error!( "failed to perform tls handshake: {:#}" , err) ;
162- return ;
150+ continue ;
163151 }
164152 } ;
165153
166- let io = TokioIo :: new ( tls_stream) ;
154+ yield TlsStream ( tls_stream) ;
155+ }
156+ }
157+ }
167158
168- if let Err ( err) = ConnBuilder :: new ( TokioExecutor :: new ( ) )
169- . serve_connection ( io, svc)
170- . await
171- {
172- tracing:: error!( "failed to serve connection: {:#}" , err) ;
173- }
174- } ) ;
159+ fn plain_incoming_stream < A > (
160+ mut acceptor : A ,
161+ ) -> impl Stream < Item = Result < A :: Connection , anyhow:: Error > >
162+ where
163+ A : Accept ,
164+ {
165+ try_stream ! {
166+ loop {
167+ let conn = match poll_fn( |cx| Pin :: new( & mut acceptor) . poll_accept( cx) ) . await {
168+ Some ( Ok ( conn) ) => conn,
169+ Some ( Err ( e) ) => {
170+ tracing:: error!( "Accept error: {}" , e) ;
171+ continue ;
172+ }
173+ None => break ,
174+ } ;
175+
176+ yield conn;
177+ }
175178 }
179+ }
176180
177- Ok ( ( ) )
181+ // Wrapper for TLS stream to implement Connected
182+ pub struct TlsStream < S > ( tokio_rustls:: server:: TlsStream < S > ) ;
183+
184+ impl < S > AsyncRead for TlsStream < S >
185+ where
186+ S : AsyncRead + AsyncWrite + Unpin ,
187+ {
188+ fn poll_read (
189+ self : Pin < & mut Self > ,
190+ cx : & mut std:: task:: Context < ' _ > ,
191+ buf : & mut tokio:: io:: ReadBuf < ' _ > ,
192+ ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
193+ Pin :: new ( & mut self . get_mut ( ) . 0 ) . poll_read ( cx, buf)
194+ }
178195}
179196
180- async fn run_plain_server < A , S , B > (
181- acceptor : & mut A ,
182- svc : S ,
183- ) -> anyhow:: Result < ( ) >
197+ impl < S > AsyncWrite for TlsStream < S >
184198where
185- A : Accept ,
186- S : tower:: Service < hyper:: Request < tonic:: body:: BoxBody > , Response = hyper:: Response < B > , Error = std:: convert:: Infallible >
187- + Clone
188- + Send
189- + ' static ,
190- S :: Future : Send + ' static ,
191- S :: Response : Send + ' static ,
192- B : http_body:: Body < Data = bytes:: Bytes > + Send + ' static ,
193- B :: Error : Into < Box < dyn std:: error:: Error + Send + Sync > > + Send + Sync + ' static ,
199+ S : AsyncRead + AsyncWrite + Unpin ,
194200{
195- tracing:: info!( "serving internal rpc server without tls" ) ;
196- let wrapped_svc = TonicServiceWrapper { inner : svc } ;
197-
198- // Drive the acceptor stream manually for hyper 1.0+ compatibility
199- loop {
200- let conn = match poll_fn ( |cx| Pin :: new ( & mut * acceptor) . poll_accept ( cx) ) . await {
201- Some ( Ok ( conn) ) => conn,
202- Some ( Err ( e) ) => {
203- tracing:: error!( "Accept error: {}" , e) ;
204- continue ;
205- }
206- None => break ,
207- } ;
208-
209- let svc = wrapped_svc. clone ( ) ;
210-
211- tokio:: spawn ( async move {
212- let io = TokioIo :: new ( conn) ;
213-
214- if let Err ( err) = ConnBuilder :: new ( TokioExecutor :: new ( ) )
215- . serve_connection ( io, svc)
216- . await
217- {
218- tracing:: error!( "failed to serve connection: {:#}" , err) ;
219- }
220- } ) ;
201+ fn poll_write (
202+ self : Pin < & mut Self > ,
203+ cx : & mut std:: task:: Context < ' _ > ,
204+ buf : & [ u8 ] ,
205+ ) -> std:: task:: Poll < std:: io:: Result < usize > > {
206+ Pin :: new ( & mut self . get_mut ( ) . 0 ) . poll_write ( cx, buf)
221207 }
222208
223- Ok ( ( ) )
209+ fn poll_flush (
210+ self : Pin < & mut Self > ,
211+ cx : & mut std:: task:: Context < ' _ > ,
212+ ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
213+ Pin :: new ( & mut self . get_mut ( ) . 0 ) . poll_flush ( cx)
214+ }
215+
216+ fn poll_shutdown (
217+ self : Pin < & mut Self > ,
218+ cx : & mut std:: task:: Context < ' _ > ,
219+ ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
220+ Pin :: new ( & mut self . get_mut ( ) . 0 ) . poll_shutdown ( cx)
221+ }
222+ }
223+
224+ impl < S : Connected > Connected for TlsStream < S > {
225+ type ConnectInfo = S :: ConnectInfo ;
226+
227+ fn connect_info ( & self ) -> Self :: ConnectInfo {
228+ self . 0 . get_ref ( ) . 0 . connect_info ( )
229+ }
224230}
225231
226232fn extract_namespace < T > (
@@ -242,7 +248,7 @@ fn extract_namespace<T>(
242248 }
243249}
244250
245- fn trace_request < B > ( req : & hyper :: Request < B > , span : & Span ) {
251+ fn trace_request < B > ( req : & http :: Request < B > , span : & Span ) {
246252 let _s = span. enter ( ) ;
247253
248254 tracing:: debug!(
0 commit comments