|
3 | 3 | import datetime |
4 | 4 | from typing import Any |
5 | 5 |
|
| 6 | +from dqlitewire.constants import ValueType |
| 7 | + |
6 | 8 |
|
7 | 9 | # Type constructors |
8 | 10 | def Date(year: int, month: int, day: int) -> datetime.date: # noqa: N802 |
@@ -42,27 +44,69 @@ def Binary(data: bytes) -> bytes: # noqa: N802 |
42 | 44 | return bytes(data) |
43 | 45 |
|
44 | 46 |
|
45 | | -# Type objects for column type checking |
| 47 | +# Type objects for column type checking. |
| 48 | +# |
| 49 | +# PEP 249: "These objects represent a data type as represented in the |
| 50 | +# database. The module exports these objects: STRING, BINARY, NUMBER, |
| 51 | +# DATETIME, ROWID. The module should export a comparison for these types |
| 52 | +# and the object returned in Cursor.description[i][1]." |
| 53 | +# |
| 54 | +# Cursor.description[i][1] here is a wire-level ``ValueType`` integer |
| 55 | +# (e.g. 10 for ISO8601). The type objects below compare equal to both |
| 56 | +# the uppercase SQL type name strings (for declared-type matching) and |
| 57 | +# the matching ``ValueType`` ints. |
46 | 58 | class _DBAPIType: |
47 | | - """Base type for DB-API type objects.""" |
| 59 | + """Base type for DB-API type objects. Compares equal to matching |
| 60 | + uppercase SQL type names (str) and wire-level ``ValueType`` codes |
| 61 | + (int). |
| 62 | + """ |
48 | 63 |
|
49 | | - def __init__(self, *values: str) -> None: |
50 | | - self.values = set(values) |
| 64 | + def __init__(self, *values: str | int | ValueType) -> None: |
| 65 | + normalized: set[str | int] = set() |
| 66 | + for v in values: |
| 67 | + if isinstance(v, ValueType): |
| 68 | + normalized.add(int(v)) |
| 69 | + else: |
| 70 | + normalized.add(v) |
| 71 | + self.values = normalized |
51 | 72 |
|
52 | 73 | def __eq__(self, other: object) -> bool: |
53 | 74 | if isinstance(other, str): |
54 | 75 | return other.upper() in self.values |
| 76 | + if isinstance(other, ValueType): |
| 77 | + return int(other) in self.values |
| 78 | + if isinstance(other, int) and not isinstance(other, bool): |
| 79 | + return other in self.values |
55 | 80 | return NotImplemented |
56 | 81 |
|
57 | 82 | def __hash__(self) -> int: |
58 | 83 | return hash(frozenset(self.values)) |
59 | 84 |
|
60 | 85 |
|
61 | | -STRING = _DBAPIType("TEXT", "VARCHAR", "CHAR", "CLOB") |
62 | | -BINARY = _DBAPIType("BLOB", "BINARY", "VARBINARY") |
63 | | -NUMBER = _DBAPIType("INTEGER", "INT", "SMALLINT", "BIGINT", "REAL", "FLOAT", "DOUBLE", "NUMERIC") |
64 | | -DATETIME = _DBAPIType("DATE", "TIME", "TIMESTAMP", "DATETIME") |
65 | | -ROWID = _DBAPIType("ROWID", "INTEGER PRIMARY KEY") |
| 86 | +STRING = _DBAPIType("TEXT", "VARCHAR", "CHAR", "CLOB", ValueType.TEXT) |
| 87 | +BINARY = _DBAPIType("BLOB", "BINARY", "VARBINARY", ValueType.BLOB) |
| 88 | +NUMBER = _DBAPIType( |
| 89 | + "INTEGER", |
| 90 | + "INT", |
| 91 | + "SMALLINT", |
| 92 | + "BIGINT", |
| 93 | + "REAL", |
| 94 | + "FLOAT", |
| 95 | + "DOUBLE", |
| 96 | + "NUMERIC", |
| 97 | + ValueType.INTEGER, |
| 98 | + ValueType.FLOAT, |
| 99 | + ValueType.BOOLEAN, |
| 100 | +) |
| 101 | +DATETIME = _DBAPIType( |
| 102 | + "DATE", |
| 103 | + "TIME", |
| 104 | + "TIMESTAMP", |
| 105 | + "DATETIME", |
| 106 | + ValueType.ISO8601, |
| 107 | + ValueType.UNIXTIME, |
| 108 | +) |
| 109 | +ROWID = _DBAPIType("ROWID", "INTEGER PRIMARY KEY", ValueType.INTEGER) |
66 | 110 |
|
67 | 111 |
|
68 | 112 | # Internal conversion helpers. |
|
0 commit comments