Skip to content

Commit 9d6d580

Browse files
authored
feat: add image inpainting remove mask mode sample and test for Imagen 2 (GoogleCloudPlatform#11384)
1 parent 8d97af7 commit 9d6d580

File tree

3 files changed

+161
-0
lines changed

3 files changed

+161
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Google Cloud Vertex AI sample for editing an image using a mask mode. The
16+
mask mode is used to automatically select the background, foreground (i.e.,
17+
the primary subject of the image), or an object based on segmentation class.
18+
Inpainting can remove the object or background designated by the prompt.
19+
Example usage:
20+
python edit_image_inpainting_remove_mask_mode.py --project_id <project-id> \
21+
--location <location> --input_file <filepath> --mask_mode <mode> \
22+
--output_file <filepath> [--prompt <text>]
23+
"""
24+
25+
# [START generativeaionvertexai_imagen_edit_image_inpainting_remove_mask_mode]
26+
27+
import argparse
28+
29+
import vertexai
30+
from vertexai.preview.vision_models import Image, ImageGenerationModel
31+
32+
33+
def edit_image_inpainting_remove_mask_mode(
    project_id: str,
    location: str,
    input_file: str,
    mask_mode: str,
    output_file: str,
    prompt: str = "",
) -> vertexai.preview.vision_models.ImageGenerationResponse:
    """Edit a local image by removing the content selected by a mask mode.

    The edited image is written to ``output_file`` and its size in bytes is
    printed to stdout.

    Args:
      project_id: Google Cloud project ID, used to initialize Vertex AI.
      location: Google Cloud region, used to initialize Vertex AI.
      input_file: Local path to the input image file. Image can be in PNG or JPEG format.
      mask_mode: Mask generation mode ('background', 'foreground', or 'semantic').
      output_file: Local path to the output image file.
      prompt: The optional text prompt describing what you want to see in the
        edited image. Defaults to an empty string, matching the CLI default.

    Returns:
      The response object returned by ``ImageGenerationModel.edit_image``;
      its first image is the one saved to ``output_file``.
    """

    vertexai.init(project=project_id, location=location)

    model = ImageGenerationModel.from_pretrained("imagegeneration@006")
    base_img = Image.load_from_file(location=input_file)

    images = model.edit_image(
        base_image=base_img,
        mask_mode=mask_mode,
        prompt=prompt,
        edit_mode="inpainting-remove",
        # Optional parameters
        # For semantic mask mode, define the segmentation class IDs:
        # segmentation_classes=[7], # a cat
        # See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/image-generation#segment-ids.
    )

    images[0].save(location=output_file)

    # Optional. View the edited image in a notebook.
    # images[0].show()

    # NOTE(review): `_image_bytes` is a private attribute of the SDK image
    # object; it is kept here because the printed message depends on it.
    print(f"Created output image using {len(images[0]._image_bytes)} bytes")

    return images
75+
76+
77+
# [END generativeaionvertexai_imagen_edit_image_inpainting_remove_mask_mode]
78+
79+
if __name__ == "__main__":
    # Command-line entry point: collect the sample's inputs as flags and run it.
    cli = argparse.ArgumentParser()
    cli.add_argument("--project_id", required=True, help="Your Cloud project ID.")
    cli.add_argument(
        "--location",
        default="us-central1",
        help="The location in which to initialize Vertex AI.",
    )
    cli.add_argument(
        "--input_file",
        required=True,
        help="The local path to the input file (e.g., 'my-input.png').",
    )
    cli.add_argument(
        "--mask_mode",
        required=True,
        help="The mask generation mode ('background', 'foreground', or 'semantic').",
    )
    cli.add_argument(
        "--output_file",
        required=True,
        help="The local path to the output file (e.g., 'my-output.png').",
    )
    cli.add_argument(
        "--prompt",
        default="",
        help="The optional text prompt describing what you want to see in the edited image.",
    )
    opts = cli.parse_args()

    # Keyword arguments make the flag-to-parameter mapping explicit.
    edit_image_inpainting_remove_mask_mode(
        project_id=opts.project_id,
        location=opts.location,
        input_file=opts.input_file,
        mask_mode=opts.mask_mode,
        output_file=opts.output_file,
        prompt=opts.prompt,
    )
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
19+
import edit_image_inpainting_remove_mask_mode
20+
21+
from google.api_core.exceptions import ResourceExhausted
22+
23+
24+
# Files: the input image is read from, and the edited image is written to,
# the "test_resources" directory that sits next to this test module.
_RESOURCES = os.path.join(os.path.dirname(__file__), "test_resources")
_INPUT_FILE = os.path.join(_RESOURCES, "woman.png")
_OUTPUT_FILE = os.path.join(_RESOURCES, "sports_car.png")

# Vertex AI target: the project comes from the environment (None when the
# variable is unset); the region is fixed.
_PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
_LOCATION = "us-central1"

# Edit request passed to the sample.
_MASK_MODE = "foreground"
_PROMPT = "sports car"
31+
32+
33+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=60)
def test_edit_image_inpainting_remove_mask_mode() -> None:
    """Run the inpainting-remove sample and check it produced image data.

    The decorator retries the test with exponential backoff, for up to 60
    seconds, whenever the call raises ``ResourceExhausted``.
    """
    # Inside this module the bare name refers to the imported sample module,
    # so the function is looked up as an attribute of it.
    run_sample = (
        edit_image_inpainting_remove_mask_mode.edit_image_inpainting_remove_mask_mode
    )
    edited_images = run_sample(
        project_id=_PROJECT_ID,
        location=_LOCATION,
        input_file=_INPUT_FILE,
        mask_mode=_MASK_MODE,
        output_file=_OUTPUT_FILE,
        prompt=_PROMPT,
    )

    # Only checks that the first image carries more than 1000 bytes of data.
    first_image_size = len(edited_images[0]._image_bytes)
    assert first_image_size > 1000
721 KB
Loading

0 commit comments

Comments
 (0)