Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions postgres-derive-test/src/composites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,38 @@ fn generics() {
},
);
}

#[test]
fn duplicate_composite_field_name_does_not_panic() {
use postgres_types::{Field, Kind, Type};

// A malicious server can report a composite type whose field list contains a
// duplicate name; the generated `from_sql` then leaves one struct field
// unset. This must surface as an error rather than panicking.
#[derive(FromSql, Debug)]
#[allow(dead_code)]
struct Dup {
a: i32,
b: i32,
}

let ty = Type::new(
"Dup".to_string(),
0,
Kind::Composite(vec![
Field::new("a".to_string(), Type::INT4),
Field::new("a".to_string(), Type::INT4),
]),
"public".to_string(),
);

let raw: &[u8] = &[
0, 0, 0, 2, // field count: 2
0, 0, 0, 23, // field 0 oid: INT4
0, 0, 0, 4, 0, 0, 0, 1, // field 0 value: 1
0, 0, 0, 23, // field 1 oid: INT4 (duplicate name "a")
0, 0, 0, 4, 0, 0, 0, 1, // field 1 value: 1
];

assert!(<Dup as FromSql>::from_sql(&ty, raw).is_err());
}
4 changes: 3 additions & 1 deletion postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {

std::result::Result::Ok(#ident {
#(
#field_idents: #temp_vars.unwrap(),
// A field is left unset if the server's composite type omitted it
// (e.g. reported a duplicate field name); error rather than panic.
#field_idents: #temp_vars.ok_or("composite type is missing a field")?,
)*
})
}
Expand Down
30 changes: 30 additions & 0 deletions postgres-protocol/src/authentication/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ use std::str;

const NONCE_LENGTH: usize = 24;

/// The maximum SCRAM iteration count the client will accept from the server.
///
/// The iteration count is sent by the server and drives a PBKDF2 loop, so an
/// unbounded value lets a malicious or impersonating server force the client to
/// perform an arbitrary number of HMAC operations before authentication even
/// completes (a denial of service). 100_000 is ~24x the PostgreSQL default of
/// 4096 and matches the default cap the PostgreSQL JDBC driver (pgjdbc) adopted
/// for the same issue (CVE-2026-42198).
const MAX_ITERATION_COUNT: u32 = 100_000;

/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
Expand Down Expand Up @@ -192,6 +202,13 @@ impl ScramSha256 {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
}

if parsed.iteration_count > MAX_ITERATION_COUNT {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"SCRAM iteration count exceeds the maximum allowed",
));
}

let salt = match STANDARD.decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
Expand Down Expand Up @@ -484,4 +501,17 @@ mod test {

scram.finish(server_final.as_bytes()).unwrap();
}

#[test]
fn excessive_iteration_count_is_rejected() {
// a malicious server cannot force an unbounded PBKDF2 loop; the iteration
// count is rejected before `hi()` runs.
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
let server_first =
"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=1000000";

let mut scram =
ScramSha256::new_inner(b"foobar", ChannelBinding::unsupported(), nonce.to_string());
assert!(scram.update(server_first.as_bytes()).is_err());
}
}
10 changes: 8 additions & 2 deletions postgres-protocol/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,21 @@ impl<'a> FallibleIterator for HstoreEntries<'a> {
if key_len < 0 {
return Err("invalid key length".into());
}
let (key, buf) = self.buf.split_at(key_len as usize);
let (key, buf) = self
.buf
.split_at_checked(key_len as usize)
.ok_or("invalid key length")?;
let key = str::from_utf8(key)?;
self.buf = buf;

let value_len = self.buf.read_i32::<BigEndian>()?;
let value = if value_len < 0 {
None
} else {
let (value, buf) = self.buf.split_at(value_len as usize);
let (value, buf) = self
.buf
.split_at_checked(value_len as usize)
.ok_or("invalid value length")?;
let value = str::from_utf8(value)?;
self.buf = buf;
Some(value)
Expand Down
17 changes: 17 additions & 0 deletions postgres-protocol/src/types/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ fn hstore() {
);
}

#[test]
fn hstore_invalid_length() {
// a malicious server can declare a key (or value) length larger than the
// remaining buffer; this must error rather than panic.
let buf: &[u8] = &[
0, 0, 0, 1, // entry count: 1
0, 0, 3, 232, // key length: 1000
b'a', b'b', // only two bytes actually present
];
assert!(
hstore_from_sql(buf)
.unwrap()
.collect::<HashMap<_, _>>()
.is_err()
);
}

#[test]
fn varbit() {
let len = 12;
Expand Down
56 changes: 56 additions & 0 deletions postgres-types/src/time_02.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,30 @@ const fn base() -> PrimitiveDateTime {
PrimitiveDateTime::new(date!(2000-01-01), time!(00:00:00))
}

// `time` 0.2 represents years in the range -100_000..=100_000 and its `Add`
// implementations panic (rather than returning an error) when the result falls
// outside that range. Unlike `time` 0.3 it has no `checked_add`, so the
// resulting Julian day is validated against the representable range before the
// add is performed.
fn date_in_range(julian_day: i64) -> bool {
let min = Date::try_from_ymd(-100_000, 1, 1)
.expect("year is in range")
.julian_day();
let max = Date::try_from_ymd(100_000, 12, 31)
.expect("year is in range")
.julian_day();
(min..=max).contains(&julian_day)
}

impl<'a> FromSql<'a> for PrimitiveDateTime {
fn from_sql(_: &Type, raw: &[u8]) -> Result<PrimitiveDateTime, Box<dyn Error + Sync + Send>> {
let t = types::timestamp_from_sql(raw)?;
// adding the sub-day remainder can shift the date by at most one day, so
// a one-day margin guarantees the add below cannot overflow the range.
let julian_day = base().date().julian_day() + Duration::microseconds(t).whole_days();
if !date_in_range(julian_day - 1) || !date_in_range(julian_day + 1) {
return Err("value too large to decode".into());
}
Ok(base() + Duration::microseconds(t))
}

Expand Down Expand Up @@ -62,6 +83,10 @@ impl ToSql for OffsetDateTime {
impl<'a> FromSql<'a> for Date {
fn from_sql(_: &Type, raw: &[u8]) -> Result<Date, Box<dyn Error + Sync + Send>> {
let jd = types::date_from_sql(raw)?;
let julian_day = base().date().julian_day() + i64::from(jd);
if !date_in_range(julian_day) {
return Err("value too large to decode".into());
}
Ok(base().date() + Duration::days(i64::from(jd)))
}

Expand Down Expand Up @@ -104,3 +129,34 @@ impl ToSql for Time {
accepts!(TIME);
to_sql_checked!();
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn date_out_of_range_errors() {
// a value that would land outside `time`'s representable year range must
// error rather than panic.
let raw = i32::MAX.to_be_bytes();
assert!(<Date as FromSql>::from_sql(&Type::DATE, &raw).is_err());
}

#[test]
fn date_in_range_decodes() {
let raw = 1_000i32.to_be_bytes();
assert!(<Date as FromSql>::from_sql(&Type::DATE, &raw).is_ok());
}

#[test]
fn timestamp_out_of_range_errors() {
let raw = 9_000_000_000_000_000_000i64.to_be_bytes();
assert!(<PrimitiveDateTime as FromSql>::from_sql(&Type::TIMESTAMP, &raw).is_err());
}

#[test]
fn timestamp_in_range_decodes() {
let raw = 0i64.to_be_bytes();
assert!(<PrimitiveDateTime as FromSql>::from_sql(&Type::TIMESTAMP, &raw).is_ok());
}
}
82 changes: 78 additions & 4 deletions tokio-postgres/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{Error, Statement};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::DataRowBody;
use std::fmt;
use std::io;
use std::ops::Range;
use std::str;
use std::sync::Arc;
Expand Down Expand Up @@ -113,11 +114,21 @@ impl fmt::Debug for Row {
impl Row {
pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row {
let row = Row {
statement,
body,
ranges,
})
};
// The DataRow field count is sent by the server independently of the
// RowDescription column count; a mismatch would make column accessors
// index `ranges` out of bounds and panic, so reject it up front.
if row.ranges.len() != row.statement.columns().len() {
return Err(Error::parse(io::Error::new(
io::ErrorKind::InvalidData,
"DataRow field count does not match the number of columns",
)));
}
Ok(row)
}

/// Returns information about the columns of data in the row.
Expand Down Expand Up @@ -217,11 +228,21 @@ impl SimpleQueryRow {
body: DataRowBody,
) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
let row = SimpleQueryRow {
columns,
body,
ranges,
})
};
// The DataRow field count is sent by the server independently of the
// RowDescription column count; a mismatch would make column accessors
// index `ranges` out of bounds and panic, so reject it up front.
if row.ranges.len() != row.columns.len() {
return Err(Error::parse(io::Error::new(
io::ErrorKind::InvalidData,
"DataRow field count does not match the number of columns",
)));
}
Ok(row)
}

/// Returns information about the columns of data in the row.
Expand Down Expand Up @@ -278,3 +299,56 @@ impl SimpleQueryRow {
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}
}

#[cfg(test)]
mod test {
use bytes::BytesMut;
use postgres_protocol::message::backend::{DataRowBody, Message};

use super::*;

fn data_row(field_count: u16, fields: &[&[u8]]) -> DataRowBody {
let mut body = BytesMut::new();
body.extend_from_slice(&field_count.to_be_bytes());
for field in fields {
body.extend_from_slice(&(field.len() as i32).to_be_bytes());
body.extend_from_slice(field);
}

let mut buf = BytesMut::new();
buf.extend_from_slice(b"D");
buf.extend_from_slice(&(body.len() as i32 + 4).to_be_bytes());
buf.extend_from_slice(&body);

match Message::parse(&mut buf).unwrap().unwrap() {
Message::DataRow(body) => body,
_ => unreachable!("expected DataRow"),
}
}

fn column(name: &str) -> Column {
Column {
name: name.to_string(),
table_oid: None,
column_id: None,
type_modifier: 0,
r#type: Type::TEXT,
}
}

#[test]
fn fewer_data_row_fields_than_columns_is_rejected() {
// a server advertising two columns but sending a DataRow with a single
// field would make column accessors index out of bounds and panic.
let body = data_row(1, &[b""]);
let statement = Statement::unnamed(vec![], vec![column("a"), column("b")]);
assert!(Row::new(statement, body).is_err());
}

#[test]
fn matching_data_row_field_count_is_accepted() {
let body = data_row(2, &[b"x", b"y"]);
let statement = Statement::unnamed(vec![], vec![column("a"), column("b")]);
assert!(Row::new(statement, body).is_ok());
}
}
Loading