Skip to content

Commit 205af89

Browse files
committed
feat: add rows_affected to RowStream
Signed-off-by: Alex Chi <[email protected]>
1 parent f4d8d60 commit 205af89

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

tokio-postgres/src/copy_in.rs

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
4+
use crate::query::extract_row_affected;
45
use crate::{query, slice_iter, Error, Statement};
56
use bytes::{Buf, BufMut, BytesMut};
67
use futures_channel::mpsc;
@@ -110,14 +111,7 @@ where
110111
let this = self.as_mut().project();
111112
match ready!(this.responses.poll_next(cx))? {
112113
Message::CommandComplete(body) => {
113-
let rows = body
114-
.tag()
115-
.map_err(Error::parse)?
116-
.rsplit(' ')
117-
.next()
118-
.unwrap()
119-
.parse()
120-
.unwrap_or(0);
114+
let rows = extract_row_affected(&body)?;
121115
return Poll::Ready(Ok(rows));
122116
}
123117
_ => return Poll::Ready(Err(Error::unexpected_message())),

tokio-postgres/src/query.rs

+31-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut};
77
use futures_util::{ready, Stream};
88
use log::{debug, log_enabled, Level};
99
use pin_project_lite::pin_project;
10-
use postgres_protocol::message::backend::Message;
10+
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
1111
use postgres_protocol::message::frontend;
1212
use std::fmt;
1313
use std::marker::PhantomPinned;
@@ -52,6 +52,7 @@ where
5252
Ok(RowStream {
5353
statement,
5454
responses,
55+
rows_affected: None,
5556
_p: PhantomPinned,
5657
})
5758
}
@@ -72,10 +73,24 @@ pub async fn query_portal(
7273
Ok(RowStream {
7374
statement: portal.statement().clone(),
7475
responses,
76+
rows_affected: None,
7577
_p: PhantomPinned,
7678
})
7779
}
7880

81+
/// Extract the number of rows affected from [`CommandCompleteBody`].
82+
pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
83+
let rows = body
84+
.tag()
85+
.map_err(Error::parse)?
86+
.rsplit(' ')
87+
.next()
88+
.unwrap()
89+
.parse()
90+
.unwrap_or(0);
91+
Ok(rows)
92+
}
93+
7994
pub async fn execute<P, I>(
8095
client: &InnerClient,
8196
statement: Statement,
@@ -104,14 +119,7 @@ where
104119
match responses.next().await? {
105120
Message::DataRow(_) => {}
106121
Message::CommandComplete(body) => {
107-
rows = body
108-
.tag()
109-
.map_err(Error::parse)?
110-
.rsplit(' ')
111-
.next()
112-
.unwrap()
113-
.parse()
114-
.unwrap_or(0);
122+
rows = extract_row_affected(&body)?;
115123
}
116124
Message::EmptyQueryResponse => rows = 0,
117125
Message::ReadyForQuery(_) => return Ok(rows),
@@ -202,6 +210,7 @@ pin_project! {
202210
pub struct RowStream {
203211
statement: Statement,
204212
responses: Responses,
213+
rows_affected: Option<u64>,
205214
#[pin]
206215
_p: PhantomPinned,
207216
}
@@ -217,12 +226,22 @@ impl Stream for RowStream {
217226
Message::DataRow(body) => {
218227
return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
219228
}
220-
Message::EmptyQueryResponse
221-
| Message::CommandComplete(_)
222-
| Message::PortalSuspended => {}
229+
Message::CommandComplete(body) => {
230+
*this.rows_affected = Some(extract_row_affected(&body)?);
231+
}
232+
Message::EmptyQueryResponse | Message::PortalSuspended => {}
223233
Message::ReadyForQuery(_) => return Poll::Ready(None),
224234
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
225235
}
226236
}
227237
}
228238
}
239+
240+
impl RowStream {
241+
/// Returns the number of rows affected by the query.
242+
///
243+
/// This will be `None` if the information is not available yet.
244+
pub fn rows_affected(&self) -> Option<u64> {
245+
self.rows_affected
246+
}
247+
}

tokio-postgres/src/simple_query.rs

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
4+
use crate::query::extract_row_affected;
45
use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
56
use bytes::Bytes;
67
use fallible_iterator::FallibleIterator;
@@ -87,14 +88,7 @@ impl Stream for SimpleQueryStream {
8788
loop {
8889
match ready!(this.responses.poll_next(cx)?) {
8990
Message::CommandComplete(body) => {
90-
let rows = body
91-
.tag()
92-
.map_err(Error::parse)?
93-
.rsplit(' ')
94-
.next()
95-
.unwrap()
96-
.parse()
97-
.unwrap_or(0);
91+
let rows = extract_row_affected(&body)?;
9892
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
9993
}
10094
Message::EmptyQueryResponse => {

0 commit comments

Comments
 (0)