diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index f17a82a17..637002052 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -206,9 +206,7 @@ def _make_generate_answer_request( contents = content_types.to_contents(contents) if safety_settings: - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="new" - ) + safety_settings = safety_types.normalize_safety_settings(safety_settings) if inline_passages is not None and semantic_retriever is not None: raise ValueError( diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 0cc342096..81e087aa0 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -27,7 +27,7 @@ from google.generativeai import string_utils from google.generativeai.types import discuss_types from google.generativeai.types import model_types -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types def _make_message(content: discuss_types.MessageOptions) -> glm.Message: @@ -521,7 +521,7 @@ def _build_chat_response( response = type(response).to_dict(response) response.pop("messages") - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) if response["candidates"]: last = response["candidates"][0] diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index a0e7df1e2..4d71baf48 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -79,9 +79,7 @@ def __init__( if "/" not in model_name: model_name = "models/" + model_name self._model_name = model_name - self._safety_settings = safety_types.to_easy_safety_dict( - safety_settings, harm_category_set="new" - ) + self._safety_settings = safety_types.to_easy_safety_dict(safety_settings) self._generation_config = generation_types.to_generation_config_dict(generation_config) self._tools = content_types.to_function_library(tools) @@ -149,10 +147,10 @@ def _prepare_request( merged_gc = self._generation_config.copy() merged_gc.update(generation_config) - safety_settings = safety_types.to_easy_safety_dict(safety_settings, harm_category_set="new") + safety_settings = safety_types.to_easy_safety_dict(safety_settings) merged_ss = self._safety_settings.copy() merged_ss.update(safety_settings) - merged_ss = safety_types.normalize_safety_settings(merged_ss, harm_category_set="new") + merged_ss = safety_types.normalize_safety_settings(merged_ss) return glm.GenerateContentRequest( model=self._model_name, diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 3a147f945..e51090e1f 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -26,7 +26,7 @@ from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai import models -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types DEFAULT_TEXT_MODEL = "models/text-bison-001" EMBEDDING_MAX_BATCH_SIZE = 100 @@ -81,7 +81,7 @@ def _make_generate_text_request( max_output_tokens: int | None = None, top_p: int | None = None, top_k: int | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, ) -> 
glm.GenerateTextRequest: """ @@ -108,9 +108,7 @@ def _make_generate_text_request( """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="old" - ) + safety_settings = palm_safety_types.normalize_safety_settings(safety_settings) if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] if stop_sequences: @@ -138,7 +136,7 @@ def generate_text( max_output_tokens: int | None = None, top_p: float | None = None, top_k: float | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, client: glm.TextServiceClient | None = None, request_options: dict[str, Any] | None = None, @@ -240,11 +238,11 @@ def _generate_response( response = client.generate_text(request, **request_options) response = type(response).to_dict(response) - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) - response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) + response["safety_feedback"] = palm_safety_types.convert_safety_feedback_to_enums( response["safety_feedback"] ) - response["candidates"] = safety_types.convert_candidate_enums(response["candidates"]) + response["candidates"] = palm_safety_types.convert_candidate_enums(response["candidates"]) return Completion(_client=client, **response) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index 0cb393e5c..fa777d1d1 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -22,7 +22,7 @@ import google.ai.generativelanguage as glm from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -169,7 +169,7 @@ class ChatResponse(abc.ABC): temperature: Optional[float] candidate_count: Optional[int] candidates: List[MessageDict] - filters: List[safety_types.ContentFilterDict] + filters: List[palm_safety_types.ContentFilterDict] top_p: Optional[float] = None top_k: Optional[float] = None diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py new file mode 100644 index 000000000..9fb88cd67 --- /dev/null +++ b/google/generativeai/types/palm_safety_types.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from collections.abc import Mapping + +import enum +import typing +from typing import Dict, Iterable, List, Union + +from typing_extensions import TypedDict + + +from google.ai import generativelanguage as glm +from google.generativeai import string_utils + + +__all__ = [ + "HarmCategory", + "HarmProbability", + "HarmBlockThreshold", + "BlockedReason", + "ContentFilterDict", + "SafetyRatingDict", + "SafetySettingDict", + "SafetyFeedbackDict", +] + +# These are basic python enums, it's okay to expose them +HarmProbability = glm.SafetyRating.HarmProbability +HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold +BlockedReason = glm.ContentFilter.BlockedReason + + +class HarmCategory: + """ + Harm Categories supported by the palm-family models + """ + + HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = glm.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = glm.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = glm.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = glm.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = glm.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = glm.HarmCategory.HARM_CATEGORY_DANGEROUS.value + + +HarmCategoryOptions = Union[str, int, HarmCategory] + +# fmt: off +_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { + glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + glm.HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + + glm.HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + + glm.HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + + glm.HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + + glm.HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": 
glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": glm.HarmCategory.HARM_CATEGORY_MEDICAL, + + glm.HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, +} +# fmt: on + + +def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: + if isinstance(x, str): + x = x.lower() + return _HARM_CATEGORIES[x] + + +HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] + +# fmt: off +_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = { + HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + 0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + + HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + 1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + + HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + 2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + + HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + 3: HarmBlockThreshold.BLOCK_ONLY_HIGH, + "block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + "high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + + HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE, + 4: HarmBlockThreshold.BLOCK_NONE, + "block_none": HarmBlockThreshold.BLOCK_NONE, +} +# fmt: on + + +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: + if isinstance(x, str): + x = x.lower() + return _BLOCK_THRESHOLDS[x] + + +class ContentFilterDict(TypedDict): + reason: BlockedReason + message: str + + __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + + +def convert_filters_to_enums( + filters: Iterable[dict], +) -> List[ContentFilterDict]: + result = [] + for f in filters: + f = f.copy() + f["reason"] = BlockedReason(f["reason"]) + f = typing.cast(ContentFilterDict, f) + result.append(f) + return result + + +class SafetyRatingDict(TypedDict): + category: glm.HarmCategory + probability: HarmProbability + + __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + + +def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: + return { + "category": glm.HarmCategory(rating["category"]), + "probability": HarmProbability(rating["probability"]), + } + + +def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: + result = [] + for r in ratings: + result.append(convert_rating_to_enum(r)) + return result + + +class SafetySettingDict(TypedDict): + category: glm.HarmCategory + threshold: HarmBlockThreshold + + __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + + +class LooseSafetySettingDict(TypedDict): + category: HarmCategoryOptions + 
threshold: HarmBlockThresholdOptions + + +EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] +EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] + +SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] + + +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: + if settings is None: + return {} + elif isinstance(settings, Mapping): + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} + else: # Iterable + return { + to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings + } + + +def normalize_safety_settings( + settings: SafetySettingOptions, +) -> list[SafetySettingDict] | None: + if settings is None: + return None + if isinstance(settings, Mapping): + return [ + { + "category": to_harm_category(key), + "threshold": to_block_threshold(value), + } + for key, value in settings.items() + ] + else: + return [ + { + "category": to_harm_category(d["category"]), + "threshold": to_block_threshold(d["threshold"]), + } + for d in settings + ] + + +def convert_setting_to_enum(setting: dict) -> SafetySettingDict: + return { + "category": glm.HarmCategory(setting["category"]), + "threshold": HarmBlockThreshold(setting["threshold"]), + } + + +class SafetyFeedbackDict(TypedDict): + rating: SafetyRatingDict + setting: SafetySettingDict + + __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + + +def convert_safety_feedback_to_enums( + safety_feedback: Iterable[dict], +) -> List[SafetyFeedbackDict]: + result = [] + for sf in safety_feedback: + result.append( + { + "rating": convert_rating_to_enum(sf["rating"]), + "setting": convert_setting_to_enum(sf["setting"]), + } + ) + return result + + +def convert_candidate_enums(candidates): + result = [] + for candidate in candidates: + candidate = candidate.copy() + candidate["safety_ratings"] = convert_ratings_to_enum(candidate["safety_ratings"]) + result.append(candidate) + return result diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 7d94a5bb0..85e57c8f6 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -1,7 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
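
Taken together, the new palm_safety_types module keeps the PaLM-era harm categories and their string aliases behind the same helper names, now without a harm_category_set argument. A minimal usage sketch of those helpers, assuming the module behaves exactly as listed above; the specific settings values are illustrative only, not part of this change:

    from google.generativeai.types import palm_safety_types

    # String aliases, ints, and enum members all resolve to the same glm.HarmCategory value.
    category = palm_safety_types.to_harm_category("toxic")    # HARM_CATEGORY_TOXICITY
    threshold = palm_safety_types.to_block_threshold("low")   # BLOCK_LOW_AND_ABOVE

    # A loose mapping is normalized into the list-of-dicts shape the PaLM request builders expect.
    settings = palm_safety_types.normalize_safety_settings(
        {"medical": "block_none", "violence": threshold}
    )
    # settings == [
    #     {"category": HARM_CATEGORY_MEDICAL, "threshold": BLOCK_NONE},
    #     {"category": HARM_CATEGORY_VIOLENCE, "threshold": BLOCK_LOW_AND_ABOVE},
    # ]
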
from __future__ import annotations from collections.abc import Mapping +import enum import typing from typing import Dict, Iterable, List, Union @@ -24,105 +39,72 @@ ] # These are basic python enums, it's okay to expose them -HarmCategory = glm.HarmCategory HarmProbability = glm.SafetyRating.HarmProbability HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold BlockedReason = glm.ContentFilter.BlockedReason +import proto + + +class HarmCategory(proto.Enum): + """ + Harm Categories supported by the gemini-family model + """ + + HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = glm.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + + HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_OLD_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = { - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - HarmCategory.HARM_CATEGORY_DEROGATORY: HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - - HarmCategory.HARM_CATEGORY_TOXICITY: HarmCategory.HARM_CATEGORY_TOXICITY, - 2: HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": HarmCategory.HARM_CATEGORY_TOXICITY, - - HarmCategory.HARM_CATEGORY_VIOLENCE: HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": HarmCategory.HARM_CATEGORY_VIOLENCE, - - HarmCategory.HARM_CATEGORY_SEXUAL: HarmCategory.HARM_CATEGORY_SEXUAL, - 4: HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": HarmCategory.HARM_CATEGORY_SEXUAL, - - HarmCategory.HARM_CATEGORY_MEDICAL: HarmCategory.HARM_CATEGORY_MEDICAL, - 5: HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "med": HarmCategory.HARM_CATEGORY_MEDICAL, - - HarmCategory.HARM_CATEGORY_DANGEROUS: HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS, -} - -_NEW_HARM_CATEGORIES = { - 7: HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': 
HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { + glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + 7: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + glm.HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_old_harm_category(x: HarmCategoryOptions) -> HarmCategory: - if isinstance(x, str): - x = x.lower() - return _OLD_HARM_CATEGORIES[x] - - -def to_new_harm_category(x: HarmCategoryOptions) -> HarmCategory: +def 
to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: if isinstance(x, str): x = x.lower() - return _NEW_HARM_CATEGORIES[x] - - -def to_harm_category(x, harm_category_set): - if harm_category_set == "old": - return to_old_harm_category(x) - elif harm_category_set == "new": - return to_new_harm_category(x) - else: - raise ValueError("harm_category_set must be 'new' or 'old'") + return _HARM_CATEGORIES[x] HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] @@ -158,7 +140,7 @@ def to_harm_category(x, harm_category_set): # fmt: on -def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmCategory: +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: if isinstance(x, str): x = x.lower() return _BLOCK_THRESHOLDS[x] @@ -184,7 +166,7 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: HarmCategory + category: glm.HarmCategory probability: HarmProbability __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) @@ -192,7 +174,7 @@ class SafetyRatingDict(TypedDict): def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": HarmCategory(rating["category"]), + "category": glm.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -205,7 +187,7 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: HarmCategory + category: glm.HarmCategory threshold: HarmBlockThreshold __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) @@ -222,31 +204,26 @@ class LooseSafetySettingDict(TypedDict): SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] -def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict: +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: if settings is None: return {} elif isinstance(settings, Mapping): - return { - to_harm_category(key, harm_category_set): to_block_threshold(value) - for key, value in settings.items() - } + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} else: # Iterable return { - to_harm_category(d["category"], harm_category_set): to_block_threshold(d["threshold"]) - for d in settings + to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings } def normalize_safety_settings( settings: SafetySettingOptions, - harm_category_set, ) -> list[SafetySettingDict] | None: if settings is None: return None if isinstance(settings, Mapping): return [ { - "category": to_harm_category(key, harm_category_set), + "category": to_harm_category(key), "threshold": to_block_threshold(value), } for key, value in settings.items() @@ -254,7 +231,7 @@ def normalize_safety_settings( else: return [ { - "category": to_harm_category(d["category"], harm_category_set), + "category": to_harm_category(d["category"]), "threshold": to_block_threshold(d["threshold"]), } for d in settings @@ -263,7 +240,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": HarmCategory(setting["category"]), + "category": glm.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index f66c0fb32..61804fcaa 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -21,7 +21,7 @@ from 
typing_extensions import TypedDict from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -42,7 +42,7 @@ class BatchEmbeddingDict(TypedDict): class TextCompletion(TypedDict, total=False): output: str - safety_ratings: List[safety_types.SafetyRatingDict | None] + safety_ratings: List[palm_safety_types.SafetyRatingDict | None] citation_metadata: citation_types.CitationMetadataDict | None @@ -63,8 +63,8 @@ class Completion(abc.ABC): candidates: List[TextCompletion] result: str | None - filters: List[safety_types.ContentFilterDict | None] - safety_feedback: List[safety_types.SafetyFeedbackDict | None] + filters: List[palm_safety_types.ContentFilterDict | None] + safety_feedback: List[palm_safety_types.SafetyFeedbackDict | None] def to_dict(self) -> Dict[str, Any]: result = { diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 9d628a42c..183ccd0c3 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -22,7 +22,7 @@ from google.generativeai import discuss from google.generativeai import client import google.generativeai as genai -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from absl.testing import absltest from absl.testing import parameterized @@ -289,32 +289,32 @@ def test_receive_and_reply_with_filters(self): self.mock_response = mock_response = glm.GenerateMessageResponse( candidates=[glm.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=safety_types.BlockedReason.OTHER), + glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), + glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") filters = response.filters self.assertLen(filters, 2) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") self.mock_response = glm.GenerateMessageResponse( candidates=[glm.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) ], ) response = response.reply("Does reply work?") filters = response.filters self.assertLen(filters, 1) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) self.assertEqual( filters[0]["reason"], - safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, + palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, ) def test_chat_citations(self): diff --git a/tests/test_text.py b/tests/test_text.py index 0bc1d4e59..5dcda93b9 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -22,7 +22,7 @@ from google.generativeai import text as text_service from google.generativeai import client -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import model_types from absl.testing import absltest from absl.testing 
import parameterized @@ -246,12 +246,12 @@ def test_stop_string(self): testcase_name="basic", safety_settings=[ { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, }, ], ), @@ -275,8 +275,8 @@ def test_stop_string(self): dict( testcase_name="mixed", safety_settings={ - "medical": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, + "medical": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, }, ), ] @@ -294,7 +294,7 @@ def test_safety_settings(self, safety_settings): self.assertEqual( self.observed_requests[-1].safety_settings[0].category, - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_filters(self): @@ -302,15 +302,15 @@ def test_filters(self): candidates=[{"output": "hello"}], filters=[ { - "reason": safety_types.BlockedReason.SAFETY, + "reason": palm_safety_types.BlockedReason.SAFETY, "message": "not safe", } ], ) response = text_service.generate_text(prompt="do filters work?") - self.assertIsInstance(response.filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(response.filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(response.filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): self.responses["generate_text"] = glm.GenerateTextResponse( @@ -318,12 +318,12 @@ def test_safety_feedback(self): safety_feedback=[ { "rating": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, "setting": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, } ], @@ -332,20 +332,20 @@ def test_safety_feedback(self): response = text_service.generate_text(prompt="does safety feedback work?") self.assertIsInstance( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - safety_types.HarmCategory, + glm.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_candidate_safety_feedback(self): @@ -355,12 +355,12 @@ def test_candidate_safety_feedback(self): "output": "hello", "safety_ratings": [ { - "category": 
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "probability": safety_types.HarmProbability.LOW, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "probability": palm_safety_types.HarmProbability.LOW, }, ], } @@ -370,20 +370,20 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory, + glm.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) def test_candidate_citations(self):
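
On the Gemini side, safety_types now only carries the harassment, hate-speech, sexually-explicit, and dangerous-content categories, and to_easy_safety_dict and normalize_safety_settings likewise drop the harm_category_set argument. A sketch mirroring the merge pattern in GenerativeModel._prepare_request, assuming the helpers behave as shown in the diff above; the settings themselves are made up for illustration:

    from google.generativeai.types import safety_types

    # Constructor-level defaults and per-call overrides are each reduced to an
    # "easy" {category: threshold} dict before being merged and normalized.
    defaults = safety_types.to_easy_safety_dict({"hate": "low"})
    per_call = safety_types.to_easy_safety_dict({"harassment": "block_only_high"})

    merged = defaults.copy()
    merged.update(per_call)

    normalized = safety_types.normalize_safety_settings(merged)
    # normalized == [
    #     {"category": glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    #      "threshold": BLOCK_LOW_AND_ABOVE},
    #     {"category": glm.HarmCategory.HARM_CATEGORY_HARASSMENT,
    #      "threshold": BLOCK_ONLY_HIGH},
    # ]
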