diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index 85e57c8f6..c8368da7f 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -201,18 +201,41 @@ class LooseSafetySettingDict(TypedDict):
 EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
 EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
 
-SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+SafetySettingOptions = Union[
+    HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None
+]
+
+
+def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions):
+    block_threshold = to_block_threshold(block_threshold)
+    set(_HARM_CATEGORIES.values())
+    return {category: block_threshold for category in set(_HARM_CATEGORIES.values())}
 
 
 def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
     if settings is None:
         return {}
-    elif isinstance(settings, Mapping):
+
+    if isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
+    if isinstance(settings, Mapping):
         return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()}
+
     else:  # Iterable
-        return {
-            to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings
-        }
+        result = {}
+        for setting in settings:
+            if isinstance(setting, glm.SafetySetting):
+                result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold)
+            elif isinstance(setting, dict):
+                result[to_harm_category(setting["category"])] = to_block_threshold(
+                    setting["threshold"]
+                )
+            else:
+                raise ValueError(
+                    f"Could not understand safety setting:\n {type(setting)=}\n {setting=}"
+                )
+        return result
 
 
 def normalize_safety_settings(
@@ -220,6 +243,10 @@ def normalize_safety_settings(
 ) -> list[SafetySettingDict] | None:
     if settings is None:
         return None
+
+    if isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
     if isinstance(settings, Mapping):
         return [
             {
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 3b0c27814..a4f1e2522 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -155,11 +155,12 @@ def test_generation_config_overwrite(self, config1, config2):
     @parameterized.named_parameters(
         ["dict", {"danger": "low"}, {"danger": "high"}],
+        ["quick", "low", "high"],
         [
             "list-dict",
             [
                 dict(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                     threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                 ),
             ],
         ],
@@ -171,21 +172,21 @@ def test_generation_config_overwrite(self, config1, config2):
             "object",
             [
                 glm.SafetySetting(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                     threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                 ),
             ],
             [
                 glm.SafetySetting(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
-                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
                 ),
             ],
         ],
     )
     def test_safety_overwrite(self, safe1, safe2):
         # Safety
-        model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"})
+        model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1)
 
         self.responses["generate_content"] = [
             simple_response(" world!"),
@@ -193,22 +194,25 @@ def test_safety_overwrite(self, safe1, safe2):
         ]
 
         _ = model.generate_content("hello")
+
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
         )
 
-        _ = model.generate_content("hello", safety_settings={"danger": "high"})
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
+        _ = model.generate_content("hello", safety_settings=safe2)
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
         )
 
diff --git a/tests/test_safety.py b/tests/test_safety.py
new file mode 100644
index 000000000..f3efc4aca
--- /dev/null
+++ b/tests/test_safety.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import google.ai.generativelanguage as glm
+from google.generativeai.types import safety_types
+
+
+class SafetyTests(parameterized.TestCase):
+    """Tests are in order with the design doc."""
+
+    @parameterized.named_parameters(
+        ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE],
+        ["block_threshold2", "medium"],
+        ["block_threshold3", 2],
+        ["dict", {"danger": "medium"}],
+        ["dict2", {"danger": 2}],
+        ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}],
+        [
+            "list-dict",
+            [
+                dict(
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                ),
+            ],
+        ],
+        [
+            "list-dict2",
+            [
+                dict(category="danger", threshold="med"),
+            ],
+        ],
+    )
+    def test_safety_overwrite(self, setting):
+        setting = safety_types.to_easy_safety_dict(setting)
+        self.assertEqual(
+            setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT],
+            glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
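
Note (usage sketch, not part of the patch): the change lets a single block threshold be passed as safety_settings, and _expand_block_threshold fans that one value out to every harm category. The sketch below assumes google-generativeai is installed and that an API key has been configured via genai.configure() before the GenerativeModel calls run; the model name "gemini-pro" and the threshold values mirror the tests in this diff.

    import google.ai.generativelanguage as glm
    import google.generativeai as genai
    from google.generativeai.types import safety_types

    # A bare threshold (str, int, or HarmBlockThreshold) is expanded to every harm category.
    expanded = safety_types.to_easy_safety_dict("medium")
    assert (
        expanded[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT]
        == glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    )

    # The same shorthand is accepted wherever safety settings are taken,
    # both at model construction and per generate_content call.
    model = genai.GenerativeModel("gemini-pro", safety_settings="low")
    response = model.generate_content("hello", safety_settings="high")

The accepted forms follow the parameterized cases in tests/test_safety.py: a bare threshold, a {category: threshold} mapping, or a list of dicts/glm.SafetySetting objects.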