Skip to content

Commit 5d3cfba

Browse files
authored
Add ST_SRID (#31)
1 parent 2ee2e7c commit 5d3cfba

4 files changed

Lines changed: 194 additions & 4 deletions

File tree

rust/sedona-functions/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mod st_perimeter;
4343
mod st_point;
4444
mod st_pointzm;
4545
mod st_setsrid;
46+
mod st_srid;
4647
mod st_transform;
4748
pub mod st_union_aggr;
4849
mod st_xyzm;

rust/sedona-functions/src/register.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ pub fn default_function_set() -> FunctionSet {
8686
crate::st_pointzm::st_pointzm_udf,
8787
crate::st_transform::st_transform_udf,
8888
crate::st_setsrid::st_set_srid_udf,
89+
crate::st_srid::st_srid_udf,
8990
crate::st_xyzm::st_m_udf,
9091
crate::st_xyzm::st_x_udf,
9192
crate::st_xyzm::st_y_udf,
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use arrow_array::builder::UInt32Builder;
18+
use std::{sync::Arc, vec};
19+
20+
use crate::executor::WkbExecutor;
21+
use arrow_schema::DataType;
22+
use datafusion_common::{DataFusionError, Result};
23+
use datafusion_expr::{
24+
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
25+
};
26+
use sedona_expr::scalar_udf::{ArgMatcher, SedonaScalarKernel, SedonaScalarUDF};
27+
use sedona_schema::datatypes::SedonaType;
28+
29+
/// ST_Srid() scalar UDF implementation
30+
///
31+
/// Scalar function to return the SRID of a geometry or geography
32+
pub fn st_srid_udf() -> SedonaScalarUDF {
33+
SedonaScalarUDF::new(
34+
"st_srid",
35+
vec![Arc::new(StSrid {})],
36+
Volatility::Immutable,
37+
Some(st_srid_doc()),
38+
)
39+
}
40+
41+
fn st_srid_doc() -> Documentation {
42+
Documentation::builder(
43+
DOC_SECTION_OTHER,
44+
"Return the spatial reference system identifier (SRID) of the geometry.",
45+
"ST_SRID (geom: Geometry)",
46+
)
47+
.with_argument("geom", "geometry: Input geometry or geography")
48+
.with_sql_example("SELECT ST_SRID(polygon)".to_string())
49+
.build()
50+
}
51+
52+
#[derive(Debug)]
53+
struct StSrid {}
54+
55+
impl SedonaScalarKernel for StSrid {
56+
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
57+
let matcher = ArgMatcher::new(
58+
vec![ArgMatcher::is_geometry_or_geography()],
59+
SedonaType::Arrow(DataType::UInt32),
60+
);
61+
62+
matcher.match_args(args)
63+
}
64+
65+
fn invoke_batch(
66+
&self,
67+
arg_types: &[SedonaType],
68+
args: &[ColumnarValue],
69+
) -> Result<ColumnarValue> {
70+
let executor = WkbExecutor::new(arg_types, args);
71+
let mut builder = UInt32Builder::with_capacity(executor.num_iterations());
72+
let srid_opt = match &arg_types[0] {
73+
SedonaType::Wkb(_, Some(crs)) | SedonaType::WkbView(_, Some(crs)) => {
74+
match crs.srid()? {
75+
Some(srid) => Some(srid),
76+
None => return Err(DataFusionError::Execution("CRS has no SRID".to_string())),
77+
}
78+
}
79+
_ => Some(0),
80+
};
81+
82+
executor.execute_wkb_void(|maybe_wkb| {
83+
match maybe_wkb {
84+
Some(_wkb) => {
85+
builder.append_option(srid_opt);
86+
}
87+
_ => builder.append_null(),
88+
}
89+
90+
Ok(())
91+
})?;
92+
93+
executor.finish(Arc::new(builder.finish()))
94+
}
95+
}
96+
97+
#[cfg(test)]
98+
mod test {
99+
use datafusion_common::ScalarValue;
100+
use datafusion_expr::ScalarUDF;
101+
use sedona_schema::crs::deserialize_crs;
102+
use sedona_schema::datatypes::Edges;
103+
use sedona_testing::testers::ScalarUdfTester;
104+
use std::str::FromStr;
105+
106+
use super::*;
107+
108+
#[test]
109+
fn udf_metadata() {
110+
let udf: ScalarUDF = st_srid_udf().into();
111+
assert_eq!(udf.name(), "st_srid");
112+
assert!(udf.documentation().is_some())
113+
}
114+
115+
#[test]
116+
fn udf() {
117+
let udf: ScalarUDF = st_srid_udf().into();
118+
119+
// Test that when no CRS is set, SRID is 0
120+
let sedona_type = SedonaType::Wkb(Edges::Planar, None);
121+
let tester = ScalarUdfTester::new(udf.clone(), vec![sedona_type]);
122+
tester.assert_return_type(DataType::UInt32);
123+
let result = tester
124+
.invoke_scalar("POLYGON ((0 0, 1 0, 0 1, 0 0))")
125+
.unwrap();
126+
tester.assert_scalar_result_equals(result, 0_u32);
127+
128+
// Test that NULL input returns NULL output
129+
let result = tester.invoke_scalar(ScalarValue::Null).unwrap();
130+
tester.assert_scalar_result_equals(result, ScalarValue::Null);
131+
132+
// Test with a CRS with an EPSG code
133+
let crs_value = serde_json::Value::String("EPSG:4837".to_string());
134+
let crs = deserialize_crs(&crs_value).unwrap();
135+
let sedona_type = SedonaType::Wkb(Edges::Planar, crs.clone());
136+
let tester = ScalarUdfTester::new(udf.clone(), vec![sedona_type]);
137+
let result = tester
138+
.invoke_scalar("POLYGON ((0 0, 1 0, 0 1, 0 0))")
139+
.unwrap();
140+
tester.assert_scalar_result_equals(result, 4837_u32);
141+
142+
// Test with a CRS but null geom
143+
let result = tester.invoke_scalar(ScalarValue::Null).unwrap();
144+
tester.assert_scalar_result_equals(result, ScalarValue::Null);
145+
146+
// Call with a CRS with no SRID (should error)
147+
let crs_value = serde_json::Value::from_str("{}");
148+
let crs = deserialize_crs(&crs_value.unwrap()).unwrap();
149+
let sedona_type = SedonaType::Wkb(Edges::Planar, crs.clone());
150+
let tester = ScalarUdfTester::new(udf.clone(), vec![sedona_type]);
151+
let result = tester.invoke_scalar("POINT (0 1)");
152+
assert!(result.is_err());
153+
assert!(result.unwrap_err().to_string().contains("CRS has no SRID"));
154+
}
155+
}

rust/sedona-schema/src/crs.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ pub trait CoordinateReferenceSystem: Debug {
8888
fn to_json(&self) -> String;
8989
fn to_authority_code(&self) -> Result<Option<String>>;
9090
fn crs_equals(&self, other: &dyn CoordinateReferenceSystem) -> bool;
91+
fn srid(&self) -> Result<Option<u32>>;
9192
}
9293

9394
/// Concrete implementation of a default longitude/latitude coordinate reference system
@@ -207,6 +208,15 @@ impl CoordinateReferenceSystem for AuthorityCode {
207208
(_, _) => false,
208209
}
209210
}
211+
212+
/// Get the SRID if authority is EPSG
213+
fn srid(&self) -> Result<Option<u32>> {
214+
if self.authority.eq_ignore_ascii_case("EPSG") {
215+
Ok(self.code.parse::<u32>().ok())
216+
} else {
217+
Ok(None)
218+
}
219+
}
210220
}
211221

212222
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -272,13 +282,28 @@ impl CoordinateReferenceSystem for ProjJSON {
272282
false
273283
}
274284
}
285+
286+
fn srid(&self) -> Result<Option<u32>> {
287+
let authority_code_opt = self.to_authority_code()?;
288+
if let Some(authority_code) = authority_code_opt {
289+
if LngLat::is_authority_code_lnglat(&authority_code) {
290+
return Ok(Some(4326));
291+
}
292+
if let Some((_, code)) = AuthorityCode::split_auth_code(&authority_code) {
293+
return Ok(code.parse::<u32>().ok());
294+
}
295+
}
296+
297+
Ok(None)
298+
}
275299
}
276300

277301
pub const OGC_CRS84_PROJJSON: &str = r#"{"$schema":"https://proj.org/schemas/v0.7/projjson.schema.json","type":"GeographicCRS","name":"WGS 84 (CRS84)","datum_ensemble":{"name":"World Geodetic System 1984 ensemble","members":[{"name":"World Geodetic System 1984 (Transit)","id":{"authority":"EPSG","code":1166}},{"name":"World Geodetic System 1984 (G730)","id":{"authority":"EPSG","code":1152}},{"name":"World Geodetic System 1984 (G873)","id":{"authority":"EPSG","code":1153}},{"name":"World Geodetic System 1984 (G1150)","id":{"authority":"EPSG","code":1154}},{"name":"World Geodetic System 1984 (G1674)","id":{"authority":"EPSG","code":1155}},{"name":"World Geodetic System 1984 (G1762)","id":{"authority":"EPSG","code":1156}},{"name":"World Geodetic System 1984 (G2139)","id":{"authority":"EPSG","code":1309}},{"name":"World Geodetic System 1984 (G2296)","id":{"authority":"EPSG","code":1383}}],"ellipsoid":{"name":"WGS 84","semi_major_axis":6378137,"inverse_flattening":298.257223563},"accuracy":"2.0","id":{"authority":"EPSG","code":6326}},"coordinate_system":{"subtype":"ellipsoidal","axis":[{"name":"Geodetic longitude","abbreviation":"Lon","direction":"east","unit":"degree"},{"name":"Geodetic latitude","abbreviation":"Lat","direction":"north","unit":"degree"}]},"scope":"Not known.","area":"World.","bbox":{"south_latitude":-90,"west_longitude":-180,"north_latitude":90,"east_longitude":180},"id":{"authority":"OGC","code":"CRS84"}}"#;
278302

279303
#[cfg(test)]
280304
mod test {
281305
use super::*;
306+
const EPSG_6318_PROJJSON: &str = r#"{"$schema": "https://proj.org/schemas/v0.4/projjson.schema.json","type": "GeographicCRS","name": "NAD83(2011)","datum": {"type": "GeodeticReferenceFrame","name": "NAD83 (National Spatial Reference System 2011)","ellipsoid": {"name": "GRS 1980","semi_major_axis": 6378137,"inverse_flattening": 298.257222101}},"coordinate_system": {"subtype": "ellipsoidal","axis": [{"name": "Geodetic latitude","abbreviation": "Lat","direction": "north","unit": "degree"},{"name": "Geodetic longitude","abbreviation": "Lon","direction": "east","unit": "degree"}]},"scope": "Horizontal component of 3D system.","area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore - Alabama; Alaska; Arizona; Arkansas; California; Colorado; Connecticut; Delaware; Florida; Georgia; Idaho; Illinois; Indiana; Iowa; Kansas; Kentucky; Louisiana; Maine; Maryland; Massachusetts; Michigan; Minnesota; Mississippi; Missouri; Montana; Nebraska; Nevada; New Hampshire; New Jersey; New Mexico; New York; North Carolina; North Dakota; Ohio; Oklahoma; Oregon; Pennsylvania; Rhode Island; South Carolina; South Dakota; Tennessee; Texas; Utah; Vermont; Virginia; Washington; West Virginia; Wisconsin; Wyoming. US Virgin Islands - onshore and offshore.", "bbox": {"south_latitude": 14.92,"west_longitude": 167.65,"north_latitude": 74.71,"east_longitude": -63.88},"id": {"authority": "EPSG","code": 6318}}"#;
282307

283308
#[test]
284309
fn deserialize() {
@@ -304,6 +329,7 @@ mod test {
304329
fn crs_projjson() {
305330
let projjson = OGC_CRS84_PROJJSON.parse::<ProjJSON>().unwrap();
306331
assert_eq!(projjson.to_authority_code().unwrap().unwrap(), "OGC:CRS84");
332+
assert_eq!(projjson.srid().unwrap(), Some(4326));
307333

308334
let json_value: Value = serde_json::from_str(OGC_CRS84_PROJJSON).unwrap();
309335
let json_value_roundtrip: Value = serde_json::from_str(&projjson.to_json()).unwrap();
@@ -317,7 +343,11 @@ mod test {
317343
.to_authority_code()
318344
.unwrap()
319345
.is_none());
320-
assert!(!projjson.crs_equals(&projjson_without_identifier))
346+
assert!(!projjson.crs_equals(&projjson_without_identifier));
347+
348+
let projjson = EPSG_6318_PROJJSON.parse::<ProjJSON>().unwrap();
349+
assert_eq!(projjson.to_authority_code().unwrap().unwrap(), "EPSG:6318");
350+
assert_eq!(projjson.srid().unwrap(), Some(6318));
321351
}
322352

323353
#[test]
@@ -328,6 +358,7 @@ mod test {
328358
};
329359
assert!(auth_code.crs_equals(&auth_code));
330360
assert!(!auth_code.crs_equals(LngLat::crs().unwrap().as_ref()));
361+
assert_eq!(auth_code.srid().unwrap(), Some(4269));
331362

332363
assert_eq!(
333364
auth_code.to_authority_code().unwrap(),
@@ -345,18 +376,20 @@ mod test {
345376
assert!(AuthorityCode::is_authority_code(&auth_code_parsed.unwrap()));
346377

347378
let value: Value = serde_json::from_str("\"EPSG:4269\"").unwrap();
348-
let new_crs = deserialize_crs(&value).unwrap();
379+
let new_crs = deserialize_crs(&value).unwrap().unwrap();
349380
assert_eq!(
350-
new_crs.unwrap().to_authority_code().unwrap(),
381+
new_crs.to_authority_code().unwrap(),
351382
Some("EPSG:4269".to_string())
352383
);
384+
assert_eq!(new_crs.srid().unwrap(), Some(4269));
353385

354386
// Ensure we can also just pass a code here
355387
let value: Value = serde_json::from_str("\"4269\"").unwrap();
356388
let new_crs = deserialize_crs(&value).unwrap();
357389
assert_eq!(
358-
new_crs.unwrap().to_authority_code().unwrap(),
390+
new_crs.clone().unwrap().to_authority_code().unwrap(),
359391
Some("EPSG:4269".to_string())
360392
);
393+
assert_eq!(new_crs.unwrap().srid().unwrap(), Some(4269));
361394
}
362395
}

0 commit comments

Comments
 (0)