From dd1bce1171782849a54b11b8c19297413fbfd299 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 24 Jun 2026 21:15:27 +0000 Subject: [PATCH] feat(bigframes): add ai.classify, ai.score, ai.if_ to the df bq accessor --- .../bigframes/bigquery/_operations/ai.py | 5 +- .../extensions/core/dataframe_accessor.py | 85 ++++++++ .../tests/system/small/bigquery/test_ai.py | 13 +- .../core/test_dataframe_accessor.py | 188 ++++++++++++++++++ 4 files changed, 289 insertions(+), 2 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 78b5d81b6744..40d5556de400 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -1178,12 +1178,15 @@ def _separate_context_and_series( Input: ("str1", series1, "str2", "str3", series2) Output: ["str1", None, "str2", "str3", None], [series1, series2] """ - if not isinstance(prompt, (str, list, tuple, series.Series)): + if not isinstance(prompt, (str, list, tuple, series.Series, pd.Series)): raise ValueError(f"Unsupported prompt type: {type(prompt)}") if isinstance(prompt, str): return [None], [series.Series([prompt])] + if isinstance(prompt, pd.Series): + return [None], [bpd.read_pandas(prompt)] + if isinstance(prompt, series.Series): if prompt.dtype == dtypes.OBJ_REF_DTYPE: # Multi-model support diff --git a/packages/bigframes/bigframes/extensions/core/dataframe_accessor.py b/packages/bigframes/bigframes/extensions/core/dataframe_accessor.py index c8fa49e41584..e490aa907dc4 100644 --- a/packages/bigframes/bigframes/extensions/core/dataframe_accessor.py +++ b/packages/bigframes/bigframes/extensions/core/dataframe_accessor.py @@ -214,6 +214,91 @@ def generate_double( ) return self._to_series(result) + def classify( + self, + input: PROMPT_TYPE, + categories: tuple[str, ...] 
| list[str], + *, + examples: list[tuple[str, str]] + | list[tuple[str, list[str] | tuple[str, ...]]] + | None = None, + connection_id: str | None = None, + endpoint: str | None = None, + output_mode: Literal["single", "multi"] | None = None, + optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None, + max_error_ratio: float | None = None, + ) -> S: + """ + Classifies a given input into one of the specified categories. It will always return one of the provided categories that best fits the prompt input. + + This is an accessor for :func:`bigframes.bigquery.ai.classify`. See that + function's documentation for detailed parameter descriptions and examples. + """ + import bigframes.bigquery.ai + + result = bigframes.bigquery.ai.classify( + input, + categories, + examples=examples, + connection_id=connection_id, + endpoint=endpoint, + output_mode=output_mode, + optimization_mode=optimization_mode, + max_error_ratio=max_error_ratio, + ) + return self._to_series(result) + + def if_( + self, + prompt: PROMPT_TYPE, + *, + connection_id: str | None = None, + endpoint: str | None = None, + optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None, + max_error_ratio: float | None = None, + ) -> S: + """ + Evaluates the prompt to True or False. Compared to ``ai.generate_bool()``, this function + provides optimization such that not all rows are evaluated with the LLM. + + This is an accessor for :func:`bigframes.bigquery.ai.if_`. See that + function's documentation for detailed parameter descriptions and examples. 
+ """ + import bigframes.bigquery.ai + + result = bigframes.bigquery.ai.if_( + prompt, + connection_id=connection_id, + endpoint=endpoint, + optimization_mode=optimization_mode, + max_error_ratio=max_error_ratio, + ) + return self._to_series(result) + + def score( + self, + prompt: PROMPT_TYPE, + *, + connection_id: str | None = None, + endpoint: str | None = None, + max_error_ratio: float | None = None, + ) -> S: + """ + Computes a score based on rubrics described in natural language. It will return a double value. + + This is an accessor for :func:`bigframes.bigquery.ai.score`. See that + function's documentation for detailed parameter descriptions and examples. + """ + import bigframes.bigquery.ai + + result = bigframes.bigquery.ai.score( + prompt, + connection_id=connection_id, + endpoint=endpoint, + max_error_ratio=max_error_ratio, + ) + return self._to_series(result) + class BigQueryDataFrameAccessor(AbstractBigQueryDataFrameAccessor[T, S]): """ diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index f3c94edd1969..05ebea141440 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -55,7 +55,7 @@ def _create_mock_obj_ref_df(session, uris, name="image", connection=None): return session.read_gbq(table_id) -def test_ai_function_pandas_input(session): +def test_ai_function_pandas_tuple_input(session): s1 = pd.Series(["apple", "bear"]) s2 = bpd.Series(["fruit", "tree"], session=session) prompt = (s1, " is a ", s2) @@ -74,6 +74,17 @@ def test_ai_function_pandas_input(session): ) +def test_ai_function_pandas_series_input(session): + s = pd.Series(["cat", "lavender"]) + + result = bbq.ai.classify( + s, categories=["animal", "plant"], endpoint="gemini-2.5-flash" + ) + + assert len(result) == len(s) + assert result.dtype == dtypes.STRING_DTYPE + + def test_ai_function_string_input(session): with mock.patch( 
"bigframes.core.global_session.get_global_session" diff --git a/packages/bigframes/tests/unit/extensions/core/test_dataframe_accessor.py b/packages/bigframes/tests/unit/extensions/core/test_dataframe_accessor.py index 2f3352116aff..c207070bb151 100644 --- a/packages/bigframes/tests/unit/extensions/core/test_dataframe_accessor.py +++ b/packages/bigframes/tests/unit/extensions/core/test_dataframe_accessor.py @@ -335,3 +335,191 @@ def test_bigframes_ai_generate_double(scalar_types_df: bpd.DataFrame, monkeypatc } result_series.to_pandas.assert_not_called() assert actual_result is result_series + + +def test_ai_classify(monkeypatch): + mock_classify = mock.MagicMock() + result_series = mock.create_autospec(bpd.Series) + mock_classify.return_value = result_series + expected_result = mock.create_autospec(pd.Series) + result_series.to_pandas.return_value = expected_result + + monkeypatch.setattr(bigframes.bigquery.ai, "classify", mock_classify) + + input_prompt = mock.create_autospec(pd.Series) + df = pd.DataFrame({"text_input": ["Is this a positive review?"]}) + actual_result = df.bigquery.ai.classify( + input_prompt, + categories=["Mammal", "Fish"], + examples=[("Cat", "Mammal")], + connection_id="conn", + endpoint="endpoint", + output_mode="single", + optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + + mock_classify.assert_called_once_with( + input_prompt, + ["Mammal", "Fish"], + examples=[("Cat", "Mammal")], + connection_id="conn", + endpoint="endpoint", + output_mode="single", + optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + result_series.to_pandas.assert_called_once() + assert actual_result is expected_result + + +def test_bigframes_ai_classify(scalar_types_df: bpd.DataFrame, monkeypatch): + bf_series = mock.create_autospec(bpd.Series) + result_series = mock.create_autospec(bpd.Series) + + mock_classify = mock.MagicMock() + mock_classify.return_value = result_series + + monkeypatch.setattr(bigframes.bigquery.ai, "classify", 
mock_classify) + + actual_result = scalar_types_df.bigquery.ai.classify( + bf_series, + categories=["Mammal", "Fish"], + examples=[("Cat", "Mammal")], + connection_id="conn", + endpoint="endpoint", + output_mode="single", + optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + + mock_classify.assert_called_once() + args, kwargs = mock_classify.call_args + assert args[0] is bf_series + assert args[1] == ["Mammal", "Fish"] + assert kwargs == { + "examples": [("Cat", "Mammal")], + "connection_id": "conn", + "endpoint": "endpoint", + "output_mode": "single", + "optimization_mode": "minimize_cost", + "max_error_ratio": 0.1, + } + result_series.to_pandas.assert_not_called() + assert actual_result is result_series + + +def test_ai_if(monkeypatch): + mock_if = mock.MagicMock() + result_series = mock.create_autospec(bpd.Series) + mock_if.return_value = result_series + expected_result = mock.create_autospec(pd.Series) + result_series.to_pandas.return_value = expected_result + + monkeypatch.setattr(bigframes.bigquery.ai, "if_", mock_if) + + prompt = mock.create_autospec(pd.Series) + df = pd.DataFrame({"text_input": ["Is this a positive review?"]}) + actual_result = df.bigquery.ai.if_( + prompt, + connection_id="conn", + endpoint="endpoint", + optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + + mock_if.assert_called_once_with( + prompt, + connection_id="conn", + endpoint="endpoint", + optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + result_series.to_pandas.assert_called_once() + assert actual_result is expected_result + + +def test_bigframes_ai_if(scalar_types_df: bpd.DataFrame, monkeypatch): + bf_series = mock.create_autospec(bpd.Series) + result_series = mock.create_autospec(bpd.Series) + + mock_if = mock.MagicMock() + mock_if.return_value = result_series + + monkeypatch.setattr(bigframes.bigquery.ai, "if_", mock_if) + + actual_result = scalar_types_df.bigquery.ai.if_( + bf_series, + connection_id="conn", + endpoint="endpoint", + 
optimization_mode="minimize_cost", + max_error_ratio=0.1, + ) + + mock_if.assert_called_once() + args, kwargs = mock_if.call_args + assert args[0] is bf_series + assert kwargs == { + "connection_id": "conn", + "endpoint": "endpoint", + "optimization_mode": "minimize_cost", + "max_error_ratio": 0.1, + } + result_series.to_pandas.assert_not_called() + assert actual_result is result_series + + +def test_ai_score(monkeypatch): + mock_score = mock.MagicMock() + result_series = mock.create_autospec(bpd.Series) + mock_score.return_value = result_series + expected_result = mock.create_autospec(pd.Series) + result_series.to_pandas.return_value = expected_result + + monkeypatch.setattr(bigframes.bigquery.ai, "score", mock_score) + + prompt = mock.create_autospec(pd.Series) + df = pd.DataFrame({"text_input": ["Is this a positive review?"]}) + actual_result = df.bigquery.ai.score( + prompt, + connection_id="conn", + endpoint="endpoint", + max_error_ratio=0.1, + ) + + mock_score.assert_called_once_with( + prompt, + connection_id="conn", + endpoint="endpoint", + max_error_ratio=0.1, + ) + result_series.to_pandas.assert_called_once() + assert actual_result is expected_result + + +def test_bigframes_ai_score(scalar_types_df: bpd.DataFrame, monkeypatch): + bf_series = mock.create_autospec(bpd.Series) + result_series = mock.create_autospec(bpd.Series) + + mock_score = mock.MagicMock() + mock_score.return_value = result_series + + monkeypatch.setattr(bigframes.bigquery.ai, "score", mock_score) + + actual_result = scalar_types_df.bigquery.ai.score( + bf_series, + connection_id="conn", + endpoint="endpoint", + max_error_ratio=0.1, + ) + + mock_score.assert_called_once() + args, kwargs = mock_score.call_args + assert args[0] is bf_series + assert kwargs == { + "connection_id": "conn", + "endpoint": "endpoint", + "max_error_ratio": 0.1, + } + result_series.to_pandas.assert_not_called() + assert actual_result is result_series