Skip to content

Commit f08b4e8

Browse files
authored
feat: Optional input_schema for ODFV (#6308) (#6312)
1 parent a310eaf commit f08b4e8

3 files changed

Lines changed: 308 additions & 11 deletions

File tree

docs/reference/beta-on-demand-feature-view.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,42 @@ def driver_aggregated_stats(inputs):
6969

7070
Aggregated columns are automatically named using the pattern `{function}_{column}` (e.g., `sum_trips`, `mean_rating`).
7171

72+
### Using `input_schema` with Aggregations
73+
74+
When the input data is not already stored as a feature view, use `input_schema` instead of `sources` to describe the fields that will be passed at request time. Feast will create an internal `RequestSource` automatically.
75+
76+
```python
77+
from datetime import timedelta
78+
from feast import Field, on_demand_feature_view
79+
from feast.aggregation import Aggregation
80+
from feast.types import Float64, Int64
81+
82+
@on_demand_feature_view(
83+
input_schema=[
84+
Field(name="txn_amount", dtype=Float64),
85+
],
86+
schema=[
87+
Field(name="txn_count", dtype=Int64),
88+
Field(name="total_txn_amount", dtype=Float64),
89+
Field(name="avg_txn_amount", dtype=Float64),
90+
],
91+
aggregations=[
92+
Aggregation(column="txn_amount", function="count", name="txn_count",
93+
time_window=timedelta(days=30)),
94+
Aggregation(column="txn_amount", function="sum", name="total_txn_amount",
95+
time_window=timedelta(days=30)),
96+
Aggregation(column="txn_amount", function="mean", name="avg_txn_amount",
97+
time_window=timedelta(days=30)),
98+
],
99+
entities=[user],
100+
)
101+
def user_transaction_stats(inputs):
102+
# Aggregations replace the transformation function — no body needed.
103+
pass
104+
```
105+
106+
`input_schema` also accepts fields that are not aggregation columns — for example, thresholds, currency codes, or other contextual values passed at request time that your UDF needs but that are not stored as features.
107+
72108
## Example
73109
See [https://github.com/feast-dev/on-demand-feature-views-demo](https://github.com/feast-dev/on-demand-feature-views-demo) for an example on how to use on demand feature views.
74110

sdk/python/feast/on_demand_feature_view.py

Lines changed: 105 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class OnDemandFeatureView(BaseFeatureView):
134134
"""
135135

136136
_TRACK_METRICS_TAG = "feast:track_metrics"
137+
_INPUT_SCHEMA_SOURCE_PREFIX = "__input_schema__"
137138

138139
name: str
139140
entities: Optional[List[str]]
@@ -158,7 +159,8 @@ def __init__( # noqa: C901
158159
name: str,
159160
entities: Optional[List[Entity]] = None,
160161
schema: Optional[List[Field]] = None,
161-
sources: List[OnDemandSourceType],
162+
sources: Optional[List[OnDemandSourceType]] = None,
163+
input_schema: Optional[List[Field]] = None,
162164
udf: Optional[FunctionType] = None,
163165
udf_string: Optional[str] = "",
164166
feature_transformation: Optional[Transformation] = None,
@@ -183,6 +185,11 @@ def __init__( # noqa: C901
183185
sources (optional): A list of input sources, which may be
184186
feature views, or request data sources. These sources serve as inputs to the udf,
185187
which will refer to them by name.
188+
input_schema (optional): A list of Fields describing data that is accepted as input
189+
but not stored directly as features — e.g. aggregation columns, normalization
190+
parameters, thresholds, or other contextual values passed at request time.
191+
When provided, sources is not required — an internal RequestSource will be
192+
created automatically.
186193
udf: The user defined transformation function, which must take pandas
187194
dataframes as inputs.
188195
udf_string: The source code version of the udf (for diffing and displaying in Web UI)
@@ -214,15 +221,44 @@ def __init__( # noqa: C901
214221
self.version = version
215222
schema = schema or []
216223
self.entities = [e.name for e in entities] if entities else [DUMMY_ENTITY_NAME]
217-
self.sources = sources
224+
self.input_schema = input_schema
218225
self.mode = mode.lower()
219226
self.udf = udf
220227
self.udf_string = udf_string
221228
self.source_feature_view_projections: dict[str, FeatureViewProjection] = {}
222229
self.source_request_sources: dict[str, RequestSource] = {}
230+
self._input_schema_sentinel: Optional[RequestSource] = None
231+
232+
# Strip any existing sentinel from sources (handles __copy__ round-trip)
233+
effective_sources: List[OnDemandSourceType] = [
234+
s
235+
for s in (sources or [])
236+
if not (
237+
isinstance(s, RequestSource)
238+
and s.name.startswith(self._INPUT_SCHEMA_SOURCE_PREFIX)
239+
)
240+
]
241+
242+
if input_schema is not None:
243+
# Automatically create an internal RequestSource from input_schema.
244+
# Stored privately so it does not appear in source_request_sources for
245+
# external consumers (e.g. the feature server, apply(), utils.py).
246+
self._input_schema_sentinel = RequestSource(
247+
name=f"{self._INPUT_SCHEMA_SOURCE_PREFIX}{name}",
248+
schema=input_schema,
249+
)
250+
self.source_request_sources[self._input_schema_sentinel.name] = (
251+
self._input_schema_sentinel
252+
)
253+
elif not effective_sources:
254+
raise ValueError(
255+
"Either 'sources' or 'input_schema' must be provided for OnDemandFeatureView."
256+
)
257+
258+
self.sources = effective_sources
223259

224260
# Process each source with explicit type handling
225-
for odfv_source in sources:
261+
for odfv_source in effective_sources:
226262
self._add_source_to_collections(odfv_source)
227263

228264
features: List[Field] = []
@@ -274,6 +310,20 @@ def __init__( # noqa: C901
274310
self.track_metrics = track_metrics
275311
self.aggregations = aggregations or []
276312

313+
if input_schema is not None and self.aggregations:
314+
input_field_names = {f.name for f in input_schema}
315+
unknown = [
316+
agg.column
317+
for agg in self.aggregations
318+
if agg.column and agg.column not in input_field_names
319+
]
320+
if unknown:
321+
raise ValueError(
322+
f"Aggregation column(s) {unknown} not found in input_schema "
323+
f"for OnDemandFeatureView '{name}'. "
324+
f"Available fields: {sorted(input_field_names)}"
325+
)
326+
277327
def _add_source_to_collections(self, odfv_source: OnDemandSourceType) -> None:
278328
"""
279329
Add a source to the appropriate collection with explicit type checking.
@@ -328,6 +378,7 @@ def __copy__(self):
328378
schema=self.features,
329379
sources=list(self.source_feature_view_projections.values())
330380
+ list(self.source_request_sources.values()),
381+
input_schema=self.input_schema,
331382
feature_transformation=self.feature_transformation,
332383
mode=self.mode,
333384
description=self.description,
@@ -337,6 +388,7 @@ def __copy__(self):
337388
singleton=self.singleton,
338389
version=self.version,
339390
track_metrics=self.track_metrics,
391+
aggregations=self.aggregations,
340392
)
341393
fv.entities = self.entities
342394
fv.features = self.features
@@ -536,6 +588,14 @@ def to_proto(self) -> OnDemandFeatureViewProto:
536588
request_data_source=request_sources.to_proto()
537589
)
538590

591+
# Serialize the input_schema sentinel so that from_proto() can reconstruct
592+
# input_schema correctly; it is excluded from source_request_sources so that
593+
# external consumers never see it as a real data source.
594+
if self._input_schema_sentinel is not None:
595+
sources[self._input_schema_sentinel.name] = OnDemandSource(
596+
request_data_source=self._input_schema_sentinel.to_proto()
597+
)
598+
539599
feature_transformation = transformation_to_proto(self.feature_transformation)
540600

541601
tags = dict(self.tags) if self.tags else {}
@@ -559,7 +619,7 @@ def to_proto(self) -> OnDemandFeatureViewProto:
559619
owner=self.owner,
560620
write_to_online_store=self.write_to_online_store,
561621
singleton=self.singleton or False,
562-
aggregations=self.aggregations,
622+
aggregations=[agg.to_proto() for agg in self.aggregations],
563623
version=self.version,
564624
)
565625
return OnDemandFeatureViewProto(spec=spec, meta=meta)
@@ -585,6 +645,18 @@ def from_proto(
585645
on_demand_feature_view_proto, skip_udf=skip_udf
586646
)
587647

648+
# Detect and strip input_schema sentinel from sources
649+
input_schema: Optional[List[Field]] = None
650+
sources_without_sentinel: List[OnDemandSourceType] = []
651+
for source in sources:
652+
if isinstance(source, RequestSource) and source.name.startswith(
653+
cls._INPUT_SCHEMA_SOURCE_PREFIX
654+
):
655+
input_schema = source.schema
656+
else:
657+
sources_without_sentinel.append(source)
658+
sources = sources_without_sentinel
659+
588660
# Parse transformation from proto (skip UDF deserialization if requested)
589661
transformation = cls._parse_transformation_from_proto(
590662
on_demand_feature_view_proto, skip_udf=skip_udf
@@ -607,6 +679,7 @@ def from_proto(
607679
name=on_demand_feature_view_proto.spec.name,
608680
schema=cls._parse_features_from_proto(on_demand_feature_view_proto),
609681
sources=cast(List[OnDemandSourceType], sources),
682+
input_schema=input_schema,
610683
feature_transformation=transformation,
611684
mode=on_demand_feature_view_proto.spec.mode or "pandas",
612685
description=on_demand_feature_view_proto.spec.description,
@@ -817,6 +890,10 @@ def get_request_data_schema(self) -> dict[str, ValueType]:
817890
raise TypeError(
818891
f"Request source schema is not correct type: ${str(type(request_source.schema))}"
819892
)
893+
# Include fields from the input_schema sentinel (stored privately)
894+
if self._input_schema_sentinel is not None:
895+
for field in self._input_schema_sentinel.schema:
896+
schema[field.name] = field.dtype.to_value_type()
820897
return schema
821898

822899
def _get_projected_feature_name(self, feature: str) -> str:
@@ -1092,7 +1169,7 @@ def _is_array_type(self, dtype) -> bool:
10921169
"""Check if the dtype represents an array type."""
10931170
# TODO: use proper type checking; this currently matches substrings of str(dtype).
10941171
dtype_str = str(dtype)
1095-
return "Array" in dtype_str or "List" in dtype_str
1172+
return "Array" in dtype_str or "List" in dtype_str or "Set" in dtype_str
10961173

10971174
def _construct_random_input(
10981175
self, singleton: bool = False
@@ -1137,6 +1214,13 @@ def _construct_random_input(
11371214
sample_value = sample_values.get(value_type, default_value)
11381215
feature_dict[field.name] = sample_value
11391216

1217+
# Add input_schema fields (stored privately outside source_request_sources)
1218+
if self._input_schema_sentinel is not None:
1219+
for field in self._input_schema_sentinel.schema:
1220+
value_type = field.dtype.to_value_type()
1221+
sample_value = sample_values.get(value_type, default_value)
1222+
feature_dict[field.name] = sample_value
1223+
11401224
return feature_dict
11411225

11421226
def _get_sample_values_by_type(self) -> dict[ValueType, list[Any]]:
@@ -1224,13 +1308,17 @@ def on_demand_feature_view(
12241308
name: Optional[str] = None,
12251309
entities: Optional[List[Entity]] = None,
12261310
schema: list[Field],
1227-
sources: list[
1228-
Union[
1229-
FeatureView,
1230-
RequestSource,
1231-
FeatureViewProjection,
1311+
sources: Optional[
1312+
list[
1313+
Union[
1314+
FeatureView,
1315+
RequestSource,
1316+
FeatureViewProjection,
1317+
]
12321318
]
1233-
],
1319+
] = None,
1320+
input_schema: Optional[list[Field]] = None,
1321+
aggregations: Optional[List[Aggregation]] = None,
12341322
mode: str = "pandas",
12351323
description: str = "",
12361324
tags: Optional[dict[str, str]] = None,
@@ -1252,6 +1340,10 @@ def on_demand_feature_view(
12521340
sources (optional): A list of input sources, which may be
12531341
feature views, or request data sources. These sources serve as inputs to the udf,
12541342
which will refer to them by name.
1343+
input_schema (optional): A list of Fields describing data that is accepted as input
1344+
but not stored directly as features — e.g. aggregation columns, normalization
1345+
parameters, thresholds, or other contextual values passed at request time.
1346+
When provided, sources is not required.
12551347
mode: The mode of execution (e.g., Pandas or Python Native)
12561348
description (optional): A human-readable description.
12571349
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
@@ -1279,6 +1371,7 @@ def decorator(user_function):
12791371
on_demand_feature_view_obj = OnDemandFeatureView(
12801372
name=name if name is not None else user_function.__name__,
12811373
sources=sources,
1374+
input_schema=input_schema,
12821375
schema=schema,
12831376
mode=mode,
12841377
description=description,
@@ -1288,6 +1381,7 @@ def decorator(user_function):
12881381
entities=entities,
12891382
singleton=singleton,
12901383
track_metrics=track_metrics,
1384+
aggregations=aggregations,
12911385
udf=user_function,
12921386
udf_string=udf_string,
12931387
version=version,

0 commit comments

Comments
 (0)