Skip to content

Commit 47a1728

Browse files
committed
Special case casting for UNIONs
1 parent f41329d commit 47a1728

2 files changed

Lines changed: 78 additions & 14 deletions

File tree

src/duckdb_py/native/python_conversion.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,36 @@
1313

1414
namespace duckdb {
1515

16+
// Like DefaultCastAs, but handles UNION targets by finding the first compatible member. DefaultCastAs raises a
17+
// Conversion Error when multiple UNION members have the same type (e.g. UNION(u1 DOUBLE, u2 DOUBLE)), so for UNION
18+
// targets we resolve the member ourselves.
19+
static Value CastToTarget(Value val, const LogicalType &target_type) {
20+
if (target_type.id() != LogicalTypeId::UNION) {
21+
return val.DefaultCastAs(target_type);
22+
}
23+
24+
auto member_count = UnionType::GetMemberCount(target_type);
25+
auto &source_type = val.type();
26+
27+
// First pass: if there's an exact type match we use that
28+
for (idx_t i = 0; i < member_count; i++) {
29+
if (UnionType::GetMemberType(target_type, i) == source_type) {
30+
return Value::UNION(UnionType::CopyMemberTypes(target_type), NumericCast<uint8_t>(i), std::move(val));
31+
}
32+
}
33+
34+
// Second pass: if there's a type we can implicitly cast to, we do that
35+
for (idx_t i = 0; i < member_count; i++) {
36+
auto member_type = UnionType::GetMemberType(target_type, i);
37+
Value candidate = val;
38+
if (candidate.DefaultTryCastAs(member_type)) {
39+
return Value::UNION(UnionType::CopyMemberTypes(target_type), NumericCast<uint8_t>(i), std::move(candidate));
40+
}
41+
}
42+
throw ConversionException("Could not convert value of type %s to %s", source_type.ToString(),
43+
target_type.ToString());
44+
}
45+
1646
static Value EmptyMapValue() {
1747
auto map_type = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL);
1848
return Value::MAP(ListType::GetChildType(map_type), vector<Value>());
@@ -265,7 +295,7 @@ static Value TransformPythonLongToHugeInt(py::handle ele, const LogicalType &tar
265295
if (target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::HUGEINT) {
266296
return val;
267297
}
268-
return val.DefaultCastAs(target_type);
298+
return CastToTarget(std::move(val), target_type);
269299
}
270300
PyErr_Clear();
271301

@@ -280,7 +310,7 @@ static Value TransformPythonLongToHugeInt(py::handle ele, const LogicalType &tar
280310
if (target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::UHUGEINT) {
281311
return val;
282312
}
283-
return val.DefaultCastAs(target_type);
313+
return CastToTarget(std::move(val), target_type);
284314
}
285315

286316
void TransformPythonUnsigned(uint64_t value, Value &res) {
@@ -410,7 +440,7 @@ bool TryTransformPythonNumeric(Value &res, py::handle ele, const LogicalType &ta
410440
if (!TrySniffPythonNumeric(res, value)) {
411441
return false;
412442
}
413-
res = res.DefaultCastAs(target_type, true);
443+
res = CastToTarget(std::move(res), target_type);
414444
return true;
415445
}
416446
}
@@ -516,20 +546,20 @@ struct PythonValueConversion {
516546
target_type.ToString());
517547
}
518548
default:
519-
result = Value::DOUBLE(val).DefaultCastAs(target_type);
549+
result = CastToTarget(Value::DOUBLE(val), target_type);
520550
break;
521551
}
522552
}
523553
// Stores `val` in `result` as a DOUBLE converted to `target_type`.
// An UNKNOWN target carries no type request, so the conversion target falls back to DOUBLE.
static void HandleLongAsDouble(Value &result, const LogicalType &target_type, double val) {
	if (target_type.id() == LogicalTypeId::UNKNOWN) {
		result = CastToTarget(Value::DOUBLE(val), LogicalType::DOUBLE);
		return;
	}
	result = CastToTarget(Value::DOUBLE(val), target_type);
}
527557
// Delegates the conversion of `ele` to TransformPythonLongToHugeInt and stores the outcome in `result`.
static void HandleLongOverflow(Value &result, const LogicalType &target_type, py::handle ele) {
	Value converted = TransformPythonLongToHugeInt(ele, target_type);
	result = std::move(converted);
}
530560
// Stores `val` in `result` as a UBIGINT converted to `target_type`.
// An UNKNOWN target carries no type request, so the conversion target falls back to UBIGINT.
static void HandleUnsignedBigint(Value &result, const LogicalType &target_type, uint64_t val) {
	if (target_type.id() == LogicalTypeId::UNKNOWN) {
		result = CastToTarget(Value::UBIGINT(val), LogicalType::UBIGINT);
		return;
	}
	result = CastToTarget(Value::UBIGINT(val), target_type);
}
534564
static void HandleBigint(Value &res, const LogicalType &target_type, int64_t value) {
535565
switch (target_type.id()) {
@@ -545,7 +575,7 @@ struct PythonValueConversion {
545575
break;
546576
}
547577
default:
548-
res = Value::BIGINT(value).DefaultCastAs(target_type);
578+
res = CastToTarget(Value::BIGINT(value), target_type);
549579
break;
550580
}
551581
}
@@ -555,7 +585,7 @@ struct PythonValueConversion {
555585
(target_type.id() == LogicalTypeId::VARCHAR && !target_type.HasAlias())) {
556586
result = Value(value);
557587
} else {
558-
result = Value(value).DefaultCastAs(target_type);
588+
result = CastToTarget(Value(value), target_type);
559589
}
560590
}
561591

@@ -692,7 +722,7 @@ struct PythonVectorConversion {
692722
break;
693723
}
694724
default:
695-
FallbackValueConversion(result, result_offset, Value::DOUBLE(val).DefaultCastAs(result.GetType()));
725+
FallbackValueConversion(result, result_offset, CastToTarget(Value::DOUBLE(val), result.GetType()));
696726
break;
697727
}
698728
}
@@ -716,7 +746,7 @@ struct PythonVectorConversion {
716746
FlatVector::GetData<uint64_t>(result)[result_offset] = value;
717747
break;
718748
default:
719-
FallbackValueConversion(result, result_offset, Value::UBIGINT(value));
749+
FallbackValueConversion(result, result_offset, CastToTarget(Value::UBIGINT(value), result.GetType()));
720750
break;
721751
}
722752
}
@@ -787,7 +817,7 @@ struct PythonVectorConversion {
787817
break;
788818
}
789819
default:
790-
FallbackValueConversion(result, result_offset, Value::BIGINT(value));
820+
FallbackValueConversion(result, result_offset, CastToTarget(Value::BIGINT(value), result.GetType()));
791821
break;
792822
}
793823
}

tests/fast/test_type_conversion.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
Issue #330: Integers >64-bit lose precision via double conversion
66
"""
77

8+
import numpy as np
89
import pytest
910

1011
import duckdb
11-
from duckdb.sqltypes import BIGINT, DOUBLE, HUGEINT, UHUGEINT, VARCHAR, DuckDBPyType
12+
from duckdb.sqltypes import BIGINT, DOUBLE, FLOAT, HUGEINT, UHUGEINT, VARCHAR, DuckDBPyType
1213

1314

1415
class TestIssue115FloatToUnion:
@@ -24,22 +25,55 @@ def test_udf_float_to_union_type(self):
2425
result = conn.sql("SELECT return_float()").fetchone()[0]
2526
assert result == 1.5
2627

28+
def test_udf_float_to_ambiguous_union_type(self):
    """A float-returning UDF must work when its UNION return type lists DOUBLE twice (u3 and u5)."""
    members = {"u1": VARCHAR, "u2": BIGINT, "u3": DOUBLE, "u4": FLOAT, "u5": DOUBLE}
    con = duckdb.connect()
    con.create_function("return_float", lambda: 1.5, return_type=duckdb.union_type(members))
    row = con.sql("SELECT return_float()").fetchone()
    assert row[0] == 1.5
38+
2739
def test_udf_dict_with_float_in_union_struct(self):
    """Issue #115 repro, with a dict value UNION that names several float types alongside str and int."""
    rows = [{"a": 1, "b": 1.2}, {"a": 3, "b": 2.4}]

    def test():
        return rows

    con = duckdb.connect()
    con.create_function(
        "test",
        test,
        return_type=DuckDBPyType(list[dict[str, str | int | np.float64 | np.float32 | float]]),
    )
    fetched = con.sql("SELECT test()").fetchone()[0]
    assert len(fetched) == 2
    # Exactly two rows were asserted above, so the pairwise walk covers both "b" values.
    for entry, expected_b in zip(fetched, (1.2, 2.4)):
        assert entry["b"] == pytest.approx(expected_b)
4254

55+
def test_udf_int_to_ambiguous_union_type(self):
    """An int-returning UDF must work when its UNION return type lists BIGINT twice (u2 and u3)."""
    members = {"u1": VARCHAR, "u2": BIGINT, "u3": BIGINT}
    con = duckdb.connect()
    con.create_function("return_int", lambda: 42, return_type=duckdb.union_type(members))
    row = con.sql("SELECT return_int()").fetchone()
    assert row[0] == 42
65+
66+
def test_udf_string_to_ambiguous_union_type(self):
    """A str-returning UDF must work when its UNION return type lists VARCHAR twice (u1 and u3)."""
    members = {"u1": VARCHAR, "u2": BIGINT, "u3": VARCHAR}
    con = duckdb.connect()
    con.create_function("return_str", lambda: "hello", return_type=duckdb.union_type(members))
    row = con.sql("SELECT return_str()").fetchone()
    assert row[0] == "hello"
76+
4377

4478
class TestIssue171DictKeyCaseSensitivity:
4579
"""Dict keys differing only by case must preserve their individual values."""

0 commit comments

Comments
 (0)