Skip to content

Commit e1b1d2d

Browse files
abhijeet-dhumal and ntkathole
authored and committed
fix(spark): Use SELECT * when feature_name_columns is empty in pull_all_from_table_or_query
pull_all_from_table_or_query always builds an explicit SELECT projection from join_key_columns + feature_name_columns + timestamp_fields. When feature_name_columns=[] — the "read all source columns" signal used by FeatureBuilder.get_column_info for BatchFeatureView with TransformationMode.PYTHON, ray, and pandas — the generated SQL becomes: SELECT user_id, event_timestamp FROM source WHERE ... All raw feature columns (rating, text, helpful_vote, …) are silently dropped. The UDF receives a 2-column DataFrame and every aggregation returns null or fails. Fix: guard on feature_name_columns being non-empty before building the explicit projection; fall through to SELECT * when it is empty. Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 835cda8 commit e1b1d2d

2 files changed

Lines changed: 183 additions & 5 deletions

File tree

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,12 +387,18 @@ def pull_all_from_table_or_query(
387387
timestamp_fields = [timestamp_field]
388388
if created_timestamp_column:
389389
timestamp_fields.append(created_timestamp_column)
390-
(fields_with_aliases, aliases) = _get_fields_with_aliases(
391-
fields=join_key_columns + feature_name_columns + timestamp_fields,
392-
field_mappings=data_source.field_mapping,
393-
)
394390

395-
fields_with_alias_string = ", ".join(fields_with_aliases)
391+
if feature_name_columns:
392+
(fields_with_aliases, _) = _get_fields_with_aliases(
393+
fields=join_key_columns + feature_name_columns + timestamp_fields,
394+
field_mappings=data_source.field_mapping,
395+
)
396+
fields_with_alias_string = ", ".join(fields_with_aliases)
397+
else:
398+
# Empty feature_name_columns signals "read all source columns".
399+
# Used by BatchFeatureView with TransformationMode.PYTHON/ray/pandas where
400+
# the UDF computes output features from raw input — don't project upfront.
401+
fields_with_alias_string = "*"
396402

397403
from_expression = data_source.get_table_query_string()
398404
timestamp_filter = get_timestamp_filter_sql(
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
Unit tests for SparkOfflineStore.pull_all_from_table_or_query SQL generation.
3+
4+
Covers the bug where feature_name_columns=[] (signalling "read all source
5+
columns" for BatchFeatureView UDF transformations) caused a bare
6+
SELECT user_id, event_timestamp FROM source
7+
instead of SELECT *, silently dropping all columns the UDF needs.
8+
"""
9+
10+
from datetime import datetime, timezone
11+
from unittest.mock import MagicMock, patch
12+
13+
import pytest
14+
15+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( # noqa: E402
16+
SparkOfflineStore,
17+
SparkOfflineStoreConfig,
18+
)
19+
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( # noqa: E402
20+
SparkSource,
21+
)
22+
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig # noqa: E402
23+
from feast.repo_config import RepoConfig # noqa: E402
24+
25+
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------

# Time window passed as start_date/end_date to pull_all_from_table_or_query;
# tz-aware so no naive/aware mixing occurs inside the timestamp filter.
START = datetime(2023, 1, 1, tzinfo=timezone.utc)
END = datetime(2024, 1, 1, tzinfo=timezone.utc)

# Fixed table name returned by the mocked get_table_query_string
_TABLE_EXPR = "`raw_reviews`"
34+
35+
36+
@pytest.fixture()
def repo_config():
    """Minimal RepoConfig: local provider, sqlite online store, Spark offline store.

    Only the offline_store type matters for these tests; the rest satisfies
    RepoConfig validation without touching any external system.
    """
    return RepoConfig(
        registry="file:///tmp/registry.db",
        project="test",
        provider="local",
        online_store=SqliteOnlineStoreConfig(type="sqlite"),
        offline_store=SparkOfflineStoreConfig(type="spark"),
    )
45+
46+
47+
@pytest.fixture()
def spark_source():
    """SparkSource pointing at a parquet path; never actually read.

    get_table_query_string is patched in _run_pull_all, so the S3 path is
    only metadata and no network/S3 access occurs.
    """
    return SparkSource(
        name="raw_reviews",
        path="s3a://bucket/processed/reviews/",
        file_format="parquet",
        timestamp_field="event_timestamp",
    )
55+
56+
57+
def _run_pull_all(repo_config, spark_source, feature_name_columns):
    """
    Call pull_all_from_table_or_query with a mocked SparkSession and mocked
    data-source table resolution, then return the SQL query string.

    Two things are patched so no real Spark/S3 access occurs:
      1. get_spark_session_or_start_new_with_repoconfig -> MagicMock session
      2. spark_source.get_table_query_string -> fixed table expression
         (avoids SparkSource.validate / _load_dataframe_from_path hitting S3)

    Returns:
        The generated SQL (job.query), stripped of surrounding whitespace.
    """
    mock_spark = MagicMock()

    with (
        # Patch target is the name as looked up inside the spark module,
        # not where the helper is defined.
        patch(
            "feast.infra.offline_stores.contrib.spark_offline_store.spark"
            ".get_spark_session_or_start_new_with_repoconfig",
            return_value=mock_spark,
        ),
        patch.object(
            spark_source,
            "get_table_query_string",
            return_value=_TABLE_EXPR,
        ),
    ):
        job = SparkOfflineStore.pull_all_from_table_or_query(
            config=repo_config,
            data_source=spark_source,
            join_key_columns=["user_id"],
            feature_name_columns=feature_name_columns,
            timestamp_field="event_timestamp",
            created_timestamp_column=None,
            start_date=START,
            end_date=END,
        )

    return job.query.strip()
93+
94+
95+
def test_pull_all_with_empty_feature_cols_generates_select_star(
    repo_config, spark_source
):
    """
    feature_name_columns=[] must produce SELECT * so UDF-based
    BatchFeatureViews receive all raw source columns for aggregation.
    """
    query = _run_pull_all(repo_config, spark_source, feature_name_columns=[])

    assert query.startswith("SELECT *"), (
        "Expected 'SELECT *' when feature_name_columns=[], "
        f"got: {query[:120]!r}\n\n"
        "BatchFeatureView UDFs need all raw source columns to compute "
        "aggregations — projecting only join key + timestamp silently "
        "drops rating, text, helpful_vote, etc."
    )
    # Only inspect the projection (text before FROM) — join keys legitimately
    # appear later in the WHERE clause.
    projection = query.partition("FROM")[0]
    assert "user_id" not in projection, (
        "SELECT * must not also explicitly list join key columns"
    )
114+
115+
116+
def test_pull_all_with_feature_cols_generates_explicit_projection(
    repo_config, spark_source
):
    """
    When feature_name_columns is non-empty (normal FeatureView path),
    the query must project only the requested columns — not SELECT *.
    """
    query = _run_pull_all(
        repo_config,
        spark_source,
        feature_name_columns=["avg_rating", "review_count"],
    )

    assert "SELECT *" not in query, (
        "Non-empty feature_name_columns must produce explicit SELECT projection, not SELECT *"
    )
    # Every requested feature plus join key and timestamp must be projected.
    for expected_column in ("avg_rating", "review_count", "user_id", "event_timestamp"):
        assert expected_column in query
136+
137+
138+
def test_pull_all_empty_feature_cols_upstream_regression(repo_config, spark_source):
    """
    Regression guard: the upstream (unfixed) behaviour with feature_name_columns=[]
    produced a query that only selected join key + timestamp, dropping all columns
    the UDF needs. Verify the fixed code does NOT produce that broken query.

    Broken upstream SQL looked like:
        SELECT user_id, event_timestamp FROM ... WHERE ...
    """
    sql = _run_pull_all(repo_config, spark_source, feature_name_columns=[])

    # Check only the projection; join keys still appear in the WHERE clause.
    projection = sql.split("FROM")[0]
    assert "user_id" not in projection, (
        "Upstream bug: query projected only 'user_id, event_timestamp', "
        "silently dropping all columns needed by the BFV UDF. "
        "Fixed query should use SELECT *."
    )
155+
156+
157+
@pytest.mark.parametrize(
    "feature_cols,expect_star",
    [
        ([], True),
        (["f1"], False),
        (["f1", "f2", "f3"], False),
    ],
)
def test_pull_all_select_star_only_when_feature_cols_empty(
    repo_config, spark_source, feature_cols, expect_star
):
    """SELECT * is generated exactly when feature_name_columns is empty."""
    sql = _run_pull_all(repo_config, spark_source, feature_name_columns=feature_cols)
    star_projection = sql.strip().upper().startswith("SELECT *")
    assert star_projection == expect_star, (
        f"feature_cols={feature_cols!r}: expected SELECT *={expect_star}, got SQL: {sql[:100]!r}"
    )

0 commit comments

Comments
 (0)