Skip to content

Commit 82f708c

Browse files
committed
Add a MessageStream
1 parent f135d22 commit 82f708c

File tree

4 files changed

+87
-54
lines changed

4 files changed

+87
-54
lines changed

src/lib.rs

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@ use io::{TlsStream, TlsHandshake};
6868
use message::{Backend, RowDescriptionEntry, ReadMessage};
6969
use notification::{Notifications, Notification};
7070
use params::{ConnectParams, IntoConnectParams, UserInfo};
71+
use priv_io::MessageStream;
7172
use rows::{Rows, LazyRows};
7273
use stmt::{Statement, Column};
73-
use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, WrongType, ToSql, FromSql, Field};
7474
use transaction::{Transaction, IsolationLevel};
75+
use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, WrongType, ToSql, FromSql, Field};
7576

7677
#[macro_use]
7778
mod macros;
@@ -127,9 +128,9 @@ impl HandleNotice for LoggingNoticeHandler {
127128
#[derive(Copy, Clone, Debug)]
128129
pub struct CancelData {
129130
/// The process ID of the session.
130-
pub process_id: u32,
131+
pub process_id: i32,
131132
/// The secret key for the session.
132-
pub secret_key: u32,
133+
pub secret_key: i32,
133134
}
134135

135136
/// Attempts to cancel an in-progress query.
@@ -167,8 +168,8 @@ pub fn cancel_query<T>(params: T,
167168
let mut socket = try!(priv_io::initialize_stream(&params, tls));
168169

169170
let message = frontend::CancelRequest {
170-
process_id: data.process_id as i32,
171-
secret_key: data.secret_key as i32,
171+
process_id: data.process_id,
172+
secret_key: data.secret_key,
172173
};
173174
let mut buf = vec![];
174175
try!(frontend::Message::write(&message, &mut buf));
@@ -208,8 +209,7 @@ struct StatementInfo {
208209
}
209210

210211
struct InnerConnection {
211-
stream: BufStream<Box<TlsStream>>,
212-
io_buf: Vec<u8>,
212+
stream: MessageStream,
213213
notice_handler: Box<HandleNotice>,
214214
notifications: VecDeque<Notification>,
215215
cancel_data: CancelData,
@@ -250,8 +250,7 @@ impl InnerConnection {
250250
};
251251

252252
let mut conn = InnerConnection {
253-
stream: BufStream::new(stream),
254-
io_buf: vec![],
253+
stream: MessageStream::new(stream),
255254
next_stmt_id: 0,
256255
notice_handler: Box::new(LoggingNoticeHandler),
257256
notifications: VecDeque::new(),
@@ -280,7 +279,7 @@ impl InnerConnection {
280279
options.push(("database".to_owned(), database));
281280
}
282281

283-
try!(conn.write_message(&frontend::StartupMessage {
282+
try!(conn.stream.write_message(&frontend::StartupMessage {
284283
parameters: &options,
285284
}));
286285
try!(conn.stream.flush());
@@ -290,8 +289,8 @@ impl InnerConnection {
290289
loop {
291290
match try!(conn.read_message()) {
292291
Backend::BackendKeyData { process_id, secret_key } => {
293-
conn.cancel_data.process_id = process_id;
294-
conn.cancel_data.secret_key = secret_key;
292+
conn.cancel_data.process_id = process_id as i32;
293+
conn.cancel_data.secret_key = secret_key as i32;
295294
}
296295
Backend::ReadyForQuery { .. } => break,
297296
Backend::ErrorResponse { fields } => return DbError::new_connect(fields),
@@ -302,16 +301,6 @@ impl InnerConnection {
302301
Ok(conn)
303302
}
304303

305-
fn write_message<M>(&mut self, message: &M) -> std_io::Result<()>
306-
where M: frontend::Message
307-
{
308-
debug_assert!(!self.desynchronized);
309-
self.io_buf.clear();
310-
try!(message.write(&mut self.io_buf));
311-
try_desync!(self, self.stream.write_all(&self.io_buf));
312-
Ok(())
313-
}
314-
315304
fn read_message_with_notification(&mut self) -> std_io::Result<Backend> {
316305
debug_assert!(!self.desynchronized);
317306
loop {
@@ -388,7 +377,7 @@ impl InnerConnection {
388377
let pass = try!(user.password.ok_or_else(|| {
389378
ConnectError::ConnectParams("a password was requested but not provided".into())
390379
}));
391-
try!(self.write_message(&frontend::PasswordMessage { password: &pass }));
380+
try!(self.stream.write_message(&frontend::PasswordMessage { password: &pass }));
392381
try!(self.stream.flush());
393382
}
394383
Backend::AuthenticationMD5Password { salt } => {
@@ -403,7 +392,7 @@ impl InnerConnection {
403392
hasher.input(output.as_bytes());
404393
hasher.input(&salt);
405394
let output = format!("md5{}", hasher.result_str());
406-
try!(self.write_message(&frontend::PasswordMessage { password: &output }));
395+
try!(self.stream.write_message(&frontend::PasswordMessage { password: &output }));
407396
try!(self.stream.flush());
408397
}
409398
Backend::AuthenticationKerberosV5 |
@@ -431,16 +420,16 @@ impl InnerConnection {
431420
fn raw_prepare(&mut self, stmt_name: &str, query: &str) -> Result<(Vec<Type>, Vec<Column>)> {
432421
debug!("preparing query with name `{}`: {}", stmt_name, query);
433422

434-
try!(self.write_message(&frontend::Parse {
423+
try!(self.stream.write_message(&frontend::Parse {
435424
name: stmt_name,
436425
query: query,
437426
param_types: &[],
438427
}));
439-
try!(self.write_message(&frontend::Describe {
428+
try!(self.stream.write_message(&frontend::Describe {
440429
variant: b'S',
441430
name: stmt_name,
442431
}));
443-
try!(self.write_message(&frontend::Sync));
432+
try!(self.stream.write_message(&frontend::Sync));
444433
try!(self.stream.flush());
445434

446435
match try!(self.read_message()) {
@@ -496,10 +485,10 @@ impl InnerConnection {
496485
return DbError::new(fields);
497486
}
498487
Backend::CopyInResponse { .. } => {
499-
try!(self.write_message(&frontend::CopyFail {
488+
try!(self.stream.write_message(&frontend::CopyFail {
500489
message: "COPY queries cannot be directly executed",
501490
}));
502-
try!(self.write_message(&frontend::Sync));
491+
try!(self.stream.write_message(&frontend::Sync));
503492
try!(self.stream.flush());
504493
}
505494
Backend::CopyOutResponse { .. } => {
@@ -545,18 +534,18 @@ impl InnerConnection {
545534
}
546535
}
547536

548-
try!(self.write_message(&frontend::Bind {
537+
try!(self.stream.write_message(&frontend::Bind {
549538
portal: portal_name,
550539
statement: &stmt_name,
551540
formats: &[1],
552541
values: &values,
553542
result_formats: &[1],
554543
}));
555-
try!(self.write_message(&frontend::Execute {
544+
try!(self.stream.write_message(&frontend::Execute {
556545
portal: portal_name,
557546
max_rows: row_limit,
558547
}));
559-
try!(self.write_message(&frontend::Sync));
548+
try!(self.stream.write_message(&frontend::Sync));
560549
try!(self.stream.flush());
561550

562551
match try!(self.read_message()) {
@@ -611,11 +600,11 @@ impl InnerConnection {
611600
}
612601

613602
fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> {
614-
try!(self.write_message(&frontend::Close {
603+
try!(self.stream.write_message(&frontend::Close {
615604
variant: type_,
616605
name: name,
617606
}));
618-
try!(self.write_message(&frontend::Sync));
607+
try!(self.stream.write_message(&frontend::Sync));
619608
try!(self.stream.flush());
620609
let resp = match try!(self.read_message()) {
621610
Backend::CloseComplete => Ok(()),
@@ -815,7 +804,7 @@ impl InnerConnection {
815804
fn quick_query(&mut self, query: &str) -> Result<Vec<Vec<Option<String>>>> {
816805
check_desync!(self);
817806
debug!("executing query: {}", query);
818-
try!(self.write_message(&frontend::Query { query: query }));
807+
try!(self.stream.write_message(&frontend::Query { query: query }));
819808
try!(self.stream.flush());
820809

821810
let mut result = vec![];
@@ -830,10 +819,10 @@ impl InnerConnection {
830819
.collect());
831820
}
832821
Backend::CopyInResponse { .. } => {
833-
try!(self.write_message(&frontend::CopyFail {
822+
try!(self.stream.write_message(&frontend::CopyFail {
834823
message: "COPY queries cannot be directly executed",
835824
}));
836-
try!(self.write_message(&frontend::Sync));
825+
try!(self.stream.write_message(&frontend::Sync));
837826
try!(self.stream.flush());
838827
}
839828
Backend::ErrorResponse { fields } => {
@@ -848,7 +837,7 @@ impl InnerConnection {
848837

849838
fn finish_inner(&mut self) -> Result<()> {
850839
check_desync!(self);
851-
try!(self.write_message(&frontend::Terminate));
840+
try!(self.stream.write_message(&frontend::Terminate));
852841
try!(self.stream.flush());
853842
Ok(())
854843
}

src/priv_io.rs

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,67 @@ use io::TlsStream;
2020

2121
const DEFAULT_PORT: u16 = 5432;
2222

23+
pub struct MessageStream {
24+
stream: BufStream<Box<TlsStream>>,
25+
buf: Vec<u8>,
26+
}
27+
28+
impl MessageStream {
29+
pub fn new(stream: Box<TlsStream>) -> MessageStream {
30+
MessageStream {
31+
stream: BufStream::new(stream),
32+
buf: vec![],
33+
}
34+
}
35+
36+
pub fn get_ref(&self) -> &Box<TlsStream> {
37+
self.stream.get_ref()
38+
}
39+
40+
pub fn write_message(&mut self, message: &frontend::Message) -> io::Result<()> {
41+
self.buf.clear();
42+
try!(frontend::Message::write(message, &mut self.buf));
43+
self.stream.write_all(&self.buf)
44+
}
45+
46+
pub fn flush(&mut self) -> io::Result<()> {
47+
self.stream.flush()
48+
}
49+
}
50+
51+
impl io::Read for MessageStream {
52+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
53+
self.stream.read(buf)
54+
}
55+
}
56+
57+
impl io::BufRead for MessageStream {
58+
fn fill_buf(&mut self) -> io::Result<&[u8]> {
59+
self.stream.fill_buf()
60+
}
61+
62+
fn consume(&mut self, amt: usize) {
63+
self.stream.consume(amt)
64+
}
65+
}
66+
2367
#[doc(hidden)]
2468
pub trait StreamOptions {
2569
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>;
2670
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>;
2771
}
2872

29-
impl StreamOptions for BufStream<Box<TlsStream>> {
73+
impl StreamOptions for MessageStream {
3074
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
31-
match self.get_ref().get_ref().0 {
75+
match self.stream.get_ref().get_ref().0 {
3276
InternalStream::Tcp(ref s) => s.set_read_timeout(timeout),
3377
#[cfg(unix)]
3478
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
3579
}
3680
}
3781

3882
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
39-
match self.get_ref().get_ref().0 {
83+
match self.stream.get_ref().get_ref().0 {
4084
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
4185
#[cfg(unix)]
4286
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),

src/rows.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,11 @@ impl<'trans, 'stmt> LazyRows<'trans, 'stmt> {
351351
fn execute(&mut self) -> Result<()> {
352352
let mut conn = self.stmt.conn().conn.borrow_mut();
353353

354-
try!(conn.write_message(&frontend::Execute {
354+
try!(conn.stream.write_message(&frontend::Execute {
355355
portal: &self.name,
356356
max_rows: self.row_limit,
357357
}));
358-
try!(conn.write_message(&frontend::Sync));
358+
try!(conn.stream.write_message(&frontend::Sync));
359359
try!(conn.stream.flush());
360360
conn.read_rows(&mut self.data).map(|more_rows| self.more_rows = more_rows)
361361
}

src/stmt.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ impl<'conn> Statement<'conn> {
147147
break;
148148
}
149149
Backend::CopyInResponse { .. } => {
150-
try!(conn.write_message(&frontend::CopyFail {
150+
try!(conn.stream.write_message(&frontend::CopyFail {
151151
message: "COPY queries cannot be directly executed",
152152
}));
153-
try!(conn.write_message(&frontend::Sync));
153+
try!(conn.stream.write_message(&frontend::Sync));
154154
try!(conn.stream.flush());
155155
}
156156
Backend::CopyOutResponse { .. } => {
@@ -297,12 +297,12 @@ impl<'conn> Statement<'conn> {
297297
match fill_copy_buf(&mut buf, r, &info) {
298298
Ok(0) => break,
299299
Ok(len) => {
300-
try!(info.conn.write_message(&frontend::CopyData { data: &buf[..len] }));
300+
try!(info.conn.stream.write_message(&frontend::CopyData { data: &buf[..len] }));
301301
}
302302
Err(err) => {
303-
try!(info.conn.write_message(&frontend::CopyFail { message: "" }));
304-
try!(info.conn.write_message(&frontend::CopyDone));
305-
try!(info.conn.write_message(&frontend::Sync));
303+
try!(info.conn.stream.write_message(&frontend::CopyFail { message: "" }));
304+
try!(info.conn.stream.write_message(&frontend::CopyDone));
305+
try!(info.conn.stream.write_message(&frontend::Sync));
306306
try!(info.conn.stream.flush());
307307
match try!(info.conn.read_message()) {
308308
Backend::ErrorResponse { .. } => {
@@ -319,8 +319,8 @@ impl<'conn> Statement<'conn> {
319319
}
320320
}
321321

322-
try!(info.conn.write_message(&frontend::CopyDone));
323-
try!(info.conn.write_message(&frontend::Sync));
322+
try!(info.conn.stream.write_message(&frontend::CopyDone));
323+
try!(info.conn.stream.write_message(&frontend::Sync));
324324
try!(info.conn.stream.flush());
325325

326326
let num = match try!(info.conn.read_message()) {
@@ -368,9 +368,9 @@ impl<'conn> Statement<'conn> {
368368
let (format, column_formats) = match try!(conn.read_message()) {
369369
Backend::CopyOutResponse { format, column_formats } => (format, column_formats),
370370
Backend::CopyInResponse { .. } => {
371-
try!(conn.write_message(&frontend::CopyFail { message: "" }));
372-
try!(conn.write_message(&frontend::CopyDone));
373-
try!(conn.write_message(&frontend::Sync));
371+
try!(conn.stream.write_message(&frontend::CopyFail { message: "" }));
372+
try!(conn.stream.write_message(&frontend::CopyDone));
373+
try!(conn.stream.write_message(&frontend::Sync));
374374
try!(conn.stream.flush());
375375
match try!(conn.read_message()) {
376376
Backend::ErrorResponse { .. } => {

0 commit comments

Comments
 (0)