From 214e01f8a3f253777b13c4becb35a913e153d952 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 10 May 2024 11:47:44 -0700
Subject: [PATCH 1/5] allow safety_settings='off'

Change-Id: Ica10b399177301073424a98cb3a8b0736dc216b4
---
 google/generativeai/types/safety_types.py | 19 +++++-
 tests/test_generative_models.py           |  4 +-
 tests/test_safety.py                      | 82 +++++++++++++++++++++++
 3 files changed, 101 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_safety.py

diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index 7d94a5bb0..1dc243008 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -219,13 +219,24 @@ class LooseSafetySettingDict(TypedDict):
 EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
 EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
 
-SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+SafetySettingOptions = Union[HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+def _expand_block_threshold(block_threshold:HarmBlockThresholdOptions):
+    block_threshold = to_block_threshold(block_threshold)
+    set(_NEW_HARM_CATEGORIES.values())
+    return {
+        category: block_threshold
+        for category in set(_NEW_HARM_CATEGORIES.values())
+    }
 
 
 def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict:
     if settings is None:
         return {}
-    elif isinstance(settings, Mapping):
+
+    if harm_category_set == "new" and isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
+    if isinstance(settings, Mapping):
         return {
             to_harm_category(key, harm_category_set): to_block_threshold(value)
             for key, value in settings.items()
@@ -243,6 +254,10 @@ def normalize_safety_settings(
 ) -> list[SafetySettingDict] | None:
     if settings is None:
         return None
+
+    if harm_category_set == "new" and isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
     if isinstance(settings, Mapping):
         return [
             {
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 6fabd59e9..6290b2dfd 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -181,7 +181,7 @@ def test_generation_config_overwrite(self, config1, config2):
     )
     def test_safety_overwrite(self, safe1, safe2):
         # Safety
-        model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"})
+        model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1)
 
         self.responses["generate_content"] = [
             simple_response(" world!"),
@@ -198,7 +198,7 @@ def test_safety_overwrite(self, safe1, safe2):
             glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
         )
 
-        _ = model.generate_content("hello", safety_settings={"danger": "high"})
+        _ = model.generate_content("hello", safety_settings=safe2)
         self.assertEqual(
             self.observed_requests[-1].safety_settings[0].category,
             glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
diff --git a/tests/test_safety.py b/tests/test_safety.py
new file mode 100644
index 000000000..d7bae9af5
--- /dev/null
+++ b/tests/test_safety.py
@@ -0,0 +1,82 @@
+import collections
+from collections.abc import Iterable
+import copy
+import pathlib
+from typing import Any
+import textwrap
+import unittest.mock
+from absl.testing import absltest
+from absl.testing import parameterized
+import google.ai.generativelanguage as glm
+from google.generativeai import client as client_lib
+from google.generativeai import generative_models
+from google.generativeai.types import content_types
+from google.generativeai.types import generation_types
+
+import PIL.Image
+
+
+class SafetyTests(parameterized.TestCase):
+    """Tests are in order with the design doc."""
+
+    @parameterized.named_parameters(
+        ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE],
+        ["block_threshold", "low"],
+        ["block_threshold", 1],
+        ["dict", {"danger": "low"}, {"danger": "high"}],
+        [
+            "list-dict",
+            [
+                dict(
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                ),
+            ],
+        ],[
+            "list-dict2"
+            [
+                dict(category="danger", threshold="high"),
+            ],
+        ],
+        [
+            "object",
+            [
+                glm.SafetySetting(
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                ),
+            ],
+        ],
+    )
+    def test_safety_overwrite(self, safe1, safe2):
+        # Safety
+        model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1)
+
+        self.responses["generate_content"] = [
+            simple_response(" world!"),
+            simple_response(" world!"),
+        ]
+
+        _ = model.generate_content("hello")
+        self.assertEqual(
+            self.observed_requests[-1].safety_settings[0].category,
+            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+        )
+        self.assertEqual(
+            self.observed_requests[-1].safety_settings[0].threshold,
+            glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+        )
+
+        _ = model.generate_content("hello", safety_settings=safe2)
+        self.assertEqual(
+            self.observed_requests[-1].safety_settings[0].category,
+            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+        )
+        self.assertEqual(
+            self.observed_requests[-1].safety_settings[0].threshold,
+            glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()

From c0f03d79c5b3a02bc6b613dd6b663f10e9d1a229 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 10 May 2024 13:59:47 -0700
Subject: [PATCH 2/5] Fix tests.

Change-Id: I06cfd07397e984b9fb757b2831b419eefb8aff98
---
 google/generativeai/types/safety_types.py | 13 ++---
 tests/test_safety.py                      | 58 ++++++-----------------
 2 files changed, 22 insertions(+), 49 deletions(-)

diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index 1dc243008..77f48ebc8 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -219,15 +219,16 @@ class LooseSafetySettingDict(TypedDict):
 EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
 EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
 
-SafetySettingOptions = Union[HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+SafetySettingOptions = Union[
+    HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None
+]
+
-def _expand_block_threshold(block_threshold:HarmBlockThresholdOptions):
+def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions):
     block_threshold = to_block_threshold(block_threshold)
     set(_NEW_HARM_CATEGORIES.values())
-    return {
-        category: block_threshold
-        for category in set(_NEW_HARM_CATEGORIES.values())
-    }
+    return {category: block_threshold for category in set(_NEW_HARM_CATEGORIES.values())}
+
 
 
 def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict:
     if settings is None:
diff --git a/tests/test_safety.py b/tests/test_safety.py
index d7bae9af5..433cbd445 100644
--- a/tests/test_safety.py
+++ b/tests/test_safety.py
@@ -10,7 +10,7 @@
 import google.ai.generativelanguage as glm
 from google.generativeai import client as client_lib
 from google.generativeai import generative_models
-from google.generativeai.types import content_types
+from google.generativeai.types import safety_types
 from google.generativeai.types import generation_types
 
 import PIL.Image
@@ -20,61 +20,33 @@ class SafetyTests(parameterized.TestCase):
     """Tests are in order with the design doc."""
 
     @parameterized.named_parameters(
-        ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE],
-        ["block_threshold", "low"],
-        ["block_threshold", 1],
-        ["dict", {"danger": "low"}, {"danger": "high"}],
+        ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE],
+        ["block_threshold2", "medium"],
+        ["block_threshold3", 2],
+        ["dict", {"danger": "medium"}],
+        ["dict2", {"danger": 2}],
+        ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}],
         [
             "list-dict",
            [
                 dict(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
-                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                 ),
             ],
-        ],[
-            "list-dict2"
-            [
-                dict(category="danger", threshold="high"),
-            ],
         ],
         [
-            "object",
+            "list-dict2",
             [
-                glm.SafetySetting(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
-                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
-                ),
+                dict(category="danger", threshold="med"),
             ],
         ],
     )
-    def test_safety_overwrite(self, safe1, safe2):
-        # Safety
-        model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1)
-
-        self.responses["generate_content"] = [
-            simple_response(" world!"),
-            simple_response(" world!"),
-        ]
-
-        _ = model.generate_content("hello")
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
-            glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
-        )
-
-        _ = model.generate_content("hello", safety_settings=safe2)
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
+    def test_safety_overwrite(self, setting):
+        setting = safety_types.to_easy_safety_dict(setting, "new")
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
-            glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT],
+            glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
         )
 
 

From d3616d039ea5d8a68338251348f10143ea93d6d1 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 17 May 2024 13:20:06 -0700
Subject: [PATCH 3/5] license

Change-Id: Ifa4843831b9c1479198c2b45c5b5abad8410f448
---
 tests/test_safety.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/tests/test_safety.py b/tests/test_safety.py
index 433cbd445..4d3c32339 100644
--- a/tests/test_safety.py
+++ b/tests/test_safety.py
@@ -1,19 +1,22 @@
-import collections
-from collections.abc import Iterable
-import copy
-import pathlib
-from typing import Any
-import textwrap
-import unittest.mock
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from absl.testing import absltest
 from absl.testing import parameterized
 import google.ai.generativelanguage as glm
-from google.generativeai import client as client_lib
-from google.generativeai import generative_models
 from google.generativeai.types import safety_types
-from google.generativeai.types import generation_types
-
-import PIL.Image
 
 
 class SafetyTests(parameterized.TestCase):

From 3ee11811388943d96297824abdaad22cef9288c1 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Tue, 21 May 2024 07:09:06 -0700
Subject: [PATCH 4/5] format

Change-Id: I534837c309121cda9c8947acdd6c126c9c730d62
---
 google/generativeai/types/safety_types.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index 02b1efd46..c8368da7f 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -228,9 +228,13 @@ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict
         if isinstance(setting, glm.SafetySetting):
             result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold)
         elif isinstance(setting, dict):
-            result[to_harm_category(setting["category"])] = to_block_threshold(setting["threshold"])
+            result[to_harm_category(setting["category"])] = to_block_threshold(
+                setting["threshold"]
+            )
         else:
-            raise ValueError(f"Could not understand safety setting:\n {type(setting)=}\n {setting=}")
+            raise ValueError(
+                f"Could not understand safety setting:\n {type(setting)=}\n {setting=}"
+            )
 
     return result
 

From 4a7d8a97f7a78ef113985225946849a824f4ea08 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Tue, 21 May 2024 07:34:40 -0700
Subject: [PATCH 5/5] add test

Change-Id: I9bce66322d64b3d6296d4db7cc0a7b7b9a78763b
---
 tests/test_generative_models.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 220f04d59..a4f1e2522 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -155,6 +155,7 @@ def test_generation_config_overwrite(self, config1, config2):
 
     @parameterized.named_parameters(
         ["dict", {"danger": "low"}, {"danger": "high"}],
+        ["quick", "low", "high"],
         [
             "list-dict",
             [
@@ -193,22 +194,25 @@ def test_safety_overwrite(self, safe1, safe2):
         ]
 
         _ = model.generate_content("hello")
+
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
         )
 
         _ = model.generate_content("hello", safety_settings=safe2)
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
         )
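
---

Usage sketch (not part of the diffs above): with this series applied, a single
block threshold should be accepted anywhere safety_settings is, and expanded to
every harm category by _expand_block_threshold from patch 1. The snippet below
is only illustrative; it assumes a google-generativeai build that includes
these patches and an already-provisioned API key, and "gemini-pro" is simply
the model name the tests use.

    import google.generativeai as genai

    genai.configure(api_key="...")  # assumes a valid API key is available

    # A bare threshold -- the subject line's "off", the tests' "low" /
    # "medium" / "high", a plain int, or a HarmBlockThreshold enum value --
    # now stands in for a full {category: threshold} mapping and is applied
    # to every harm category at once.
    model = genai.GenerativeModel("gemini-pro", safety_settings="off")

    # The same shorthand works per request (cf. the "quick" case added in
    # patch 5, which passes "low" at construction and "high" at the call site).
    response = model.generate_content("hello", safety_settings="high")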