Skip to content

Commit cfdbb24

Browse files
authored
feat: add imagen samples: generate image from prompt, get captions, g… (GoogleCloudPlatform#11247)
* feat: add imagen samples: generate image from prompt, get captions, get responses (vqa) * change test question for image responses * Trigger Build
1 parent da907f0 commit cfdbb24

6 files changed

+363
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Google Cloud Vertex AI sample for generating an image using only
16+
descriptive text as an input.
17+
Example usage:
18+
python generate_image.py --project_id <project-id> --location <location> \
19+
--output_file <filepath> --prompt <text>
20+
"""
21+
22+
# [START aiplatform_imagen_generate_image]
23+
24+
import argparse
25+
26+
import vertexai
27+
from vertexai.preview.vision_models import ImageGenerationModel
28+
29+
30+
def generate_image(
31+
project_id: str, location: str, output_file: str, prompt: str
32+
) -> vertexai.preview.vision_models.ImageGenerationResponse:
33+
"""Generate an image using a text prompt.
34+
Args:
35+
project_id: Google Cloud project ID, used to initialize Vertex AI.
36+
location: Google Cloud region, used to initialize Vertex AI.
37+
output_file: Local path to the output image file.
38+
prompt: The text prompt describing what you want to see."""
39+
40+
vertexai.init(project=project_id, location=location)
41+
42+
model = ImageGenerationModel.from_pretrained("imagegeneration@005")
43+
44+
images = model.generate_images(
45+
prompt=prompt,
46+
# Optional parameters
47+
seed=1,
48+
number_of_images=1,
49+
)
50+
51+
images[0].save(location=output_file, include_generation_parameters=True)
52+
53+
# Optional. View the generated image in a notebook.
54+
# images[0].show()
55+
56+
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
57+
58+
return images
59+
60+
61+
# [END aiplatform_imagen_generate_image]
62+
63+
if __name__ == "__main__":
64+
parser = argparse.ArgumentParser()
65+
parser.add_argument("--project_id", help="Your Cloud project ID.", required=True)
66+
parser.add_argument(
67+
"--location",
68+
help="The location in which to initialize Vertex AI.",
69+
default="us-central1",
70+
)
71+
parser.add_argument(
72+
"--output_file",
73+
help="The local path to the output file (e.g., 'my-output.png').",
74+
required=True,
75+
)
76+
parser.add_argument(
77+
"--prompt",
78+
help="The text prompt describing what you want to see (e.g., 'a dog reading a newspaper').",
79+
required=True,
80+
)
81+
args = parser.parse_args()
82+
generate_image(
83+
args.project_id,
84+
args.location,
85+
args.output_file,
86+
args.prompt,
87+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
19+
import generate_image
20+
21+
from google.api_core.exceptions import ResourceExhausted
22+
23+
24+
_RESOURCES = os.path.join(os.path.dirname(__file__), "test_resources")
25+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
26+
_LOCATION = "us-central1"
27+
_OUTPUT_FILE = os.path.join(_RESOURCES, "dog_newspaper.png")
28+
_PROMPT = "a dog reading a newspaper"
29+
30+
31+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=60)
32+
def test_generate_image() -> None:
33+
response = generate_image.generate_image(
34+
_PROJECT_ID,
35+
_LOCATION,
36+
_OUTPUT_FILE,
37+
_PROMPT,
38+
)
39+
40+
assert len(response[0]._image_bytes) > 1000
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Google Cloud Vertex AI sample for getting short-form image captions.
16+
Example usage:
17+
python get_short_form_image_captions.py --project_id <project-id> --location <location> \
18+
--input_file <filepath>
19+
"""
20+
21+
# [START aiplatform_imagen_get_short_form_image_captions]
22+
23+
import argparse
24+
25+
import vertexai
26+
from vertexai.preview.vision_models import Image, ImageTextModel
27+
28+
29+
def get_short_form_image_captions(
30+
project_id: str, location: str, input_file: str
31+
) -> list:
32+
"""Get short-form captions for a local image.
33+
Args:
34+
project_id: Google Cloud project ID, used to initialize Vertex AI.
35+
location: Google Cloud region, used to initialize Vertex AI.
36+
input_file: Local path to the input image file."""
37+
38+
vertexai.init(project=project_id, location=location)
39+
40+
model = ImageTextModel.from_pretrained("imagetext@001")
41+
source_img = Image.load_from_file(location=input_file)
42+
43+
captions = model.get_captions(
44+
image=source_img,
45+
# Optional parameters
46+
language="en",
47+
number_of_results=1,
48+
)
49+
50+
print(captions)
51+
52+
return captions
53+
54+
55+
# [END aiplatform_imagen_get_short_form_image_captions]
56+
57+
if __name__ == "__main__":
58+
parser = argparse.ArgumentParser()
59+
parser.add_argument("--project_id", help="Your Cloud project ID.", required=True)
60+
parser.add_argument(
61+
"--location",
62+
help="The location in which to initialize Vertex AI.",
63+
default="us-central1",
64+
)
65+
parser.add_argument(
66+
"--input_file",
67+
help="The local path to the input file (e.g., 'my-input.png').",
68+
required=True,
69+
)
70+
args = parser.parse_args()
71+
get_short_form_image_captions(
72+
args.project_id,
73+
args.location,
74+
args.input_file,
75+
)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
19+
import get_short_form_image_captions
20+
21+
from google.api_core.exceptions import ResourceExhausted
22+
23+
24+
_RESOURCES = os.path.join(os.path.dirname(__file__), "test_resources")
25+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
26+
_LOCATION = "us-central1"
27+
_INPUT_FILE = os.path.join(_RESOURCES, "cat.png")
28+
29+
30+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=60)
31+
def test_get_short_form_image_captions() -> None:
32+
response = get_short_form_image_captions.get_short_form_image_captions(
33+
_PROJECT_ID,
34+
_LOCATION,
35+
_INPUT_FILE,
36+
)
37+
38+
assert len(response) > 0 and "cat" in response[0]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Google Cloud Vertex AI sample for getting short-form responses to a
16+
question about an image.
17+
Example usage:
18+
python get_short_form_image_responses.py --project_id <project-id> --location <location> \
19+
--input_file <filepath> --question <text>
20+
"""
21+
22+
# [START aiplatform_imagen_get_short_form_image_responses]
23+
24+
import argparse
25+
26+
import vertexai
27+
from vertexai.preview.vision_models import Image, ImageTextModel
28+
29+
30+
def get_short_form_image_responses(
31+
project_id: str, location: str, input_file: str, question: str
32+
) -> list:
33+
"""Get short-form responses to a question about a local image.
34+
Args:
35+
project_id: Google Cloud project ID, used to initialize Vertex AI.
36+
location: Google Cloud region, used to initialize Vertex AI.
37+
input_file: Local path to the input image file.
38+
question: The question about the contents of the image."""
39+
40+
vertexai.init(project=project_id, location=location)
41+
42+
model = ImageTextModel.from_pretrained("imagetext@001")
43+
source_img = Image.load_from_file(location=input_file)
44+
45+
answers = model.ask_question(
46+
image=source_img,
47+
question=question,
48+
# Optional parameters
49+
number_of_results=1,
50+
)
51+
52+
print(answers)
53+
54+
return answers
55+
56+
57+
# [END aiplatform_imagen_get_short_form_image_responses]
58+
59+
if __name__ == "__main__":
60+
parser = argparse.ArgumentParser()
61+
parser.add_argument("--project_id", help="Your Cloud project ID.", required=True)
62+
parser.add_argument(
63+
"--location",
64+
help="The location in which to initialize Vertex AI.",
65+
default="us-central1",
66+
)
67+
parser.add_argument(
68+
"--input_file",
69+
help="The local path to the input file (e.g., 'my-input.png').",
70+
required=True,
71+
)
72+
parser.add_argument(
73+
"--question",
74+
help="The question about the image (e.g., 'What breed of dog is this a picture of?').",
75+
required=True,
76+
)
77+
args = parser.parse_args()
78+
get_short_form_image_responses(
79+
args.project_id,
80+
args.location,
81+
args.input_file,
82+
args.question,
83+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
19+
import get_short_form_image_responses
20+
21+
from google.api_core.exceptions import ResourceExhausted
22+
23+
24+
_RESOURCES = os.path.join(os.path.dirname(__file__), "test_resources")
25+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
26+
_LOCATION = "us-central1"
27+
_INPUT_FILE = os.path.join(_RESOURCES, "cat.png")
28+
_QUESTION = "What breed of cat is this a picture of?"
29+
30+
31+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=60)
32+
def test_get_short_form_image_responses() -> None:
33+
response = get_short_form_image_responses.get_short_form_image_responses(
34+
_PROJECT_ID,
35+
_LOCATION,
36+
_INPUT_FILE,
37+
_QUESTION,
38+
)
39+
40+
assert len(response) > 0 and "tabby" in response[0]

0 commit comments

Comments
 (0)