Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fea8a9f

Browse files
committed
add connect_raw
1 parent 046b5b2 commit fea8a9f

File tree

2 files changed

+75
-35
lines changed

2 files changed

+75
-35
lines changed

‎src/lib.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,39 @@ where
6262
T: TlsConnect<AsyncStream>,
6363
{
6464
let stream = connect_stream(&config).await?;
65+
connect_raw(stream, config, tls).await
66+
}
67+
68+
/// Connect to postgres server with a tls connector.
69+
///
70+
/// ```rust
71+
/// use async_postgres::connect;
72+
///
73+
/// use std::error::Error;
74+
/// use async_std::task::spawn;
75+
///
76+
/// async fn play() -> Result<(), Box<dyn Error>> {
77+
/// let url = "host=localhost user=postgres";
78+
/// let (client, conn) = connect(url.parse()?).await?;
79+
/// spawn(conn);
80+
/// let row = client.query_one("SELECT * FROM user WHERE id=1ドル", &[&0]).await?;
81+
/// let value: &str = row.get(0);
82+
/// println!("value: {}", value);
83+
/// Ok(())
84+
/// }
85+
/// ```
86+
#[inline]
87+
pub async fn connect_raw<S, T>(
88+
stream: S,
89+
config: Config,
90+
tls: T,
91+
) -> io::Result<(Client, Connection<AsyncStream, T::Stream>)>
92+
where
93+
S: Into<AsyncStream>,
94+
T: TlsConnect<AsyncStream>,
95+
{
6596
config
66-
.connect_raw(stream, tls)
97+
.connect_raw(stream.into(), tls)
6798
.await
6899
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
69100
}

‎src/stream.rs

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#[cfg(unix)]
2+
use async_std::os::unix::net::UnixStream;
3+
14
use async_std::io::{self, Read, Write};
25
use async_std::net::TcpStream;
36
use std::mem::MaybeUninit;
@@ -6,11 +9,25 @@ use std::task::{Context, Poll};
69
use tokio::io::{AsyncRead, AsyncWrite};
710
use tokio_postgres::config::{Config, Host};
811

9-
/// Default port of postgres.
12+
/// Default socket port of postgres.
1013
const DEFAULT_PORT: u16 = 5432;
1114

12-
/// A wrapper for async_std::net::TcpStream, implementing tokio::io::{AsyncRead, AsyncWrite}.
13-
pub struct AsyncStream(TcpStream);
15+
/// A alias for 'static + Unpin + Send + Read + Write
16+
pub trait AsyncReadWriter: 'static + Unpin + Send + Read + Write {}
17+
18+
impl<T> AsyncReadWriter for T where T: 'static + Unpin + Send + Read + Write {}
19+
20+
/// A adaptor between futures::io::{AsyncRead, AsyncWrite} and tokio::io::{AsyncRead, AsyncWrite}.
21+
pub struct AsyncStream(Box<dyn AsyncReadWriter>);
22+
23+
impl<T> From<T> for AsyncStream
24+
where
25+
T: AsyncReadWriter,
26+
{
27+
fn from(stream: T) -> Self {
28+
Self(Box::new(stream))
29+
}
30+
}
1431

1532
impl AsyncRead for AsyncStream {
1633
#[inline]
@@ -56,39 +73,31 @@ impl AsyncWrite for AsyncStream {
5673
}
5774

5875
/// Establish connection to postgres server by AsyncStream.
76+
///
77+
///
5978
#[inline]
6079
pub async fn connect_stream(config: &Config) -> io::Result<AsyncStream> {
61-
let host = try_tcp_host(&config)?;
62-
let port = config
63-
.get_ports()
64-
.iter()
65-
.copied()
66-
.next()
67-
.unwrap_or(DEFAULT_PORT);
68-
69-
let tcp_stream = TcpStream::connect((host, port)).await?;
70-
Ok(AsyncStream(tcp_stream))
71-
}
72-
73-
/// Try to get TCP hostname from postgres config.
74-
#[inline]
75-
fn try_tcp_host(config: &Config) -> io::Result<&str> {
76-
match config
77-
.get_hosts()
78-
.iter()
79-
.filter_map(|host| {
80-
if let Host::Tcp(value) = host {
81-
Some(value)
82-
} else {
83-
None
80+
let mut error = io::Error::new(io::ErrorKind::Other, "host missing");
81+
let mut ports = config.get_ports().iter().cloned();
82+
for host in config.get_hosts() {
83+
let result = match host {
84+
#[cfg(unix)]
85+
Host::Unix(path) => UnixStream::connect(path).await.map(Into::into),
86+
Host::Tcp(tcp) => {
87+
let port = ports.next().unwrap_or(DEFAULT_PORT);
88+
TcpStream::connect((tcp.as_str(), port))
89+
.await
90+
.map(Into::into)
91+
}
92+
#[cfg(not(unix))]
93+
Host::Unix(_) => {
94+
io::Error::new(io::ErrorKind::Other, "unix domain socket is unsupported")
8495
}
85-
})
86-
.next()
87-
{
88-
Some(host) => Ok(host),
89-
None => Err(io::Error::new(
90-
io::ErrorKind::Other,
91-
"At least one tcp hostname is required",
92-
)),
96+
};
97+
match result {
98+
Err(err) => error = err,
99+
stream => return stream,
100+
}
93101
}
102+
Err(error)
94103
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /