1
+ #[ cfg( unix) ]
2
+ use async_std:: os:: unix:: net:: UnixStream ;
3
+
1
4
use async_std:: io:: { self , Read , Write } ;
2
5
use async_std:: net:: TcpStream ;
3
6
use std:: mem:: MaybeUninit ;
@@ -6,11 +9,25 @@ use std::task::{Context, Poll};
6
9
use tokio:: io:: { AsyncRead , AsyncWrite } ;
7
10
use tokio_postgres:: config:: { Config , Host } ;
8
11
9
- /// Default port of postgres.
12
+ /// Default socket port of postgres.
10
13
const DEFAULT_PORT : u16 = 5432 ;
11
14
12
- /// A wrapper for async_std::net::TcpStream, implementing tokio::io::{AsyncRead, AsyncWrite}.
13
- pub struct AsyncStream ( TcpStream ) ;
15
+ /// A alias for 'static + Unpin + Send + Read + Write
16
+ pub trait AsyncReadWriter : ' static + Unpin + Send + Read + Write { }
17
+
18
+ impl < T > AsyncReadWriter for T where T : ' static + Unpin + Send + Read + Write { }
19
+
20
+ /// A adaptor between futures::io::{AsyncRead, AsyncWrite} and tokio::io::{AsyncRead, AsyncWrite}.
21
+ pub struct AsyncStream ( Box < dyn AsyncReadWriter > ) ;
22
+
23
+ impl < T > From < T > for AsyncStream
24
+ where
25
+ T : AsyncReadWriter ,
26
+ {
27
+ fn from ( stream : T ) -> Self {
28
+ Self ( Box :: new ( stream) )
29
+ }
30
+ }
14
31
15
32
impl AsyncRead for AsyncStream {
16
33
#[ inline]
@@ -56,39 +73,31 @@ impl AsyncWrite for AsyncStream {
56
73
}
57
74
58
75
/// Establish connection to postgres server by AsyncStream.
76
+ ///
77
+ ///
59
78
#[ inline]
60
79
pub async fn connect_stream ( config : & Config ) -> io:: Result < AsyncStream > {
61
- let host = try_tcp_host ( & config) ?;
62
- let port = config
63
- . get_ports ( )
64
- . iter ( )
65
- . copied ( )
66
- . next ( )
67
- . unwrap_or ( DEFAULT_PORT ) ;
68
-
69
- let tcp_stream = TcpStream :: connect ( ( host, port) ) . await ?;
70
- Ok ( AsyncStream ( tcp_stream) )
71
- }
72
-
73
- /// Try to get TCP hostname from postgres config.
74
- #[ inline]
75
- fn try_tcp_host ( config : & Config ) -> io:: Result < & str > {
76
- match config
77
- . get_hosts ( )
78
- . iter ( )
79
- . filter_map ( |host| {
80
- if let Host :: Tcp ( value) = host {
81
- Some ( value)
82
- } else {
83
- None
80
+ let mut error = io:: Error :: new ( io:: ErrorKind :: Other , "host missing" ) ;
81
+ let mut ports = config. get_ports ( ) . iter ( ) . cloned ( ) ;
82
+ for host in config. get_hosts ( ) {
83
+ let result = match host {
84
+ #[ cfg( unix) ]
85
+ Host :: Unix ( path) => UnixStream :: connect ( path) . await . map ( Into :: into) ,
86
+ Host :: Tcp ( tcp) => {
87
+ let port = ports. next ( ) . unwrap_or ( DEFAULT_PORT ) ;
88
+ TcpStream :: connect ( ( tcp. as_str ( ) , port) )
89
+ . await
90
+ . map ( Into :: into)
91
+ }
92
+ #[ cfg( not( unix) ) ]
93
+ Host :: Unix ( _) => {
94
+ io:: Error :: new ( io:: ErrorKind :: Other , "unix domain socket is unsupported" )
84
95
}
85
- } )
86
- . next ( )
87
- {
88
- Some ( host) => Ok ( host) ,
89
- None => Err ( io:: Error :: new (
90
- io:: ErrorKind :: Other ,
91
- "At least one tcp hostname is required" ,
92
- ) ) ,
96
+ } ;
97
+ match result {
98
+ Err ( err) => error = err,
99
+ stream => return stream,
100
+ }
93
101
}
102
+ Err ( error)
94
103
}
0 commit comments