Skip to content

Commit ce85084

Browse files
feat: add fixed size list support (apache#1231)
1 parent 39980e8 commit ce85084

4 files changed

Lines changed: 36 additions & 10 deletions

File tree

src/ast/data_type.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ impl fmt::Display for DataType {
349349
DataType::Bytea => write!(f, "BYTEA"),
350350
DataType::Array(ty) => match ty {
351351
ArrayElemTypeDef::None => write!(f, "ARRAY"),
352-
ArrayElemTypeDef::SquareBracket(t) => write!(f, "{t}[]"),
352+
ArrayElemTypeDef::SquareBracket(t, None) => write!(f, "{t}[]"),
353+
ArrayElemTypeDef::SquareBracket(t, Some(size)) => write!(f, "{t}[{size}]"),
353354
ArrayElemTypeDef::AngleBracket(t) => write!(f, "ARRAY<{t}>"),
354355
},
355356
DataType::Custom(ty, modifiers) => {
@@ -592,6 +593,6 @@ pub enum ArrayElemTypeDef {
592593
None,
593594
/// `ARRAY<INT>`
594595
AngleBracket(Box<DataType>),
595-
/// `[]INT`
596-
SquareBracket(Box<DataType>),
596+
/// `INT[]` or `INT[2]`
597+
SquareBracket(Box<DataType>, Option<u64>),
597598
}

src/parser/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6360,7 +6360,7 @@ impl<'a> Parser<'a> {
63606360
&mut self,
63616361
) -> Result<(DataType, MatchedTrailingBracket), ParserError> {
63626362
let next_token = self.next_token();
6363-
let mut trailing_bracket = false.into();
6363+
let mut trailing_bracket: MatchedTrailingBracket = false.into();
63646364
let mut data = match next_token.token {
63656365
Token::Word(w) => match w.keyword {
63666366
Keyword::BOOLEAN => Ok(DataType::Boolean),
@@ -6580,8 +6580,13 @@ impl<'a> Parser<'a> {
65806580
// Parse array data types. Note: this is postgresql-specific and different from
65816581
// Keyword::ARRAY syntax from above
65826582
while self.consume_token(&Token::LBracket) {
6583+
let size = if dialect_of!(self is GenericDialect | DuckDbDialect | PostgreSqlDialect) {
6584+
self.maybe_parse(|p| p.parse_literal_uint())
6585+
} else {
6586+
None
6587+
};
65836588
self.expect_token(&Token::RBracket)?;
6584-
data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data)))
6589+
data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data), size))
65856590
}
65866591
Ok((data, trailing_bracket))
65876592
}

tests/sqlparser_common.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3132,7 +3132,7 @@ fn parse_create_table_hive_array() {
31323132
let expected = if angle_bracket_syntax {
31333133
ArrayElemTypeDef::AngleBracket(expected)
31343134
} else {
3135-
ArrayElemTypeDef::SquareBracket(expected)
3135+
ArrayElemTypeDef::SquareBracket(expected, None)
31363136
};
31373137

31383138
match dialects.one_statement_parses_to(sql.as_str(), sql.as_str()) {
@@ -9257,3 +9257,21 @@ fn test_select_wildcard_with_replace() {
92579257
});
92589258
assert_eq!(expected, select.projection[0]);
92599259
}
9260+
9261+
#[test]
9262+
fn parse_sized_list() {
9263+
let dialects = TestedDialects {
9264+
dialects: vec![
9265+
Box::new(GenericDialect {}),
9266+
Box::new(PostgreSqlDialect {}),
9267+
Box::new(DuckDbDialect {}),
9268+
],
9269+
options: None,
9270+
};
9271+
let sql = r#"CREATE TABLE embeddings (data FLOAT[1536])"#;
9272+
dialects.verified_stmt(sql);
9273+
let sql = r#"CREATE TABLE embeddings (data FLOAT[1536][3])"#;
9274+
dialects.verified_stmt(sql);
9275+
let sql = r#"SELECT data::FLOAT[1536] FROM embeddings"#;
9276+
dialects.verified_stmt(sql);
9277+
}

tests/sqlparser_postgres.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,11 +1917,13 @@ fn parse_array_index_expr() {
19171917
})],
19181918
named: true,
19191919
})),
1920-
data_type: DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(
1921-
DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(DataType::Int(
1920+
data_type: DataType::Array(ArrayElemTypeDef::SquareBracket(
1921+
Box::new(DataType::Array(ArrayElemTypeDef::SquareBracket(
1922+
Box::new(DataType::Int(None)),
19221923
None
1923-
))))
1924-
))),
1924+
))),
1925+
None
1926+
)),
19251927
format: None,
19261928
}))),
19271929
indexes: vec![num[1].clone(), num[2].clone()],

0 commit comments

Comments
 (0)