Commit 3c5c702

committed
add text generation tutorial with transformers
1 parent 248f882 commit 3c5c702

4 files changed: +299 -0 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# [Text Generation with Transformers in Python](https://www.thepythoncode.com/article/text-generation-with-transformers-in-python)
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TextGeneration-Transformers-PythonCodeTutorial.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "6bjli5Z7ZEVh"
      },
      "source": [
        "!pip install transformers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SPADZcRSY-3Y"
      },
      "source": [
        "from transformers import pipeline"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k0zHPjIkqcEx"
      },
      "source": [
        "# download & load GPT-2 model\n",
        "gpt2_generator = pipeline('text-generation', model='gpt2')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "me1PAIvlqwKf"
      },
      "source": [
        "# generate 3 different sentences\n",
        "# results are sampled from the top 50 candidates\n",
        "sentences = gpt2_generator(\"To be honest, neural networks\", do_sample=True, top_k=50, temperature=0.6, max_length=128, num_return_sequences=3)\n",
        "for sentence in sentences:\n",
        "    print(sentence[\"generated_text\"])\n",
        "    print(\"=\"*50)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aXI92oauZCD4"
      },
      "source": [
        "# download & load GPT-J model! It's 22.5GB in size\n",
        "gpt_j_generator = pipeline('text-generation', model='EleutherAI/gpt-j-6B')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EaOAqXnXtOI0"
      },
      "source": [
        "# generate sentences with TOP-K sampling\n",
        "sentences = gpt_j_generator(\"To be honest, robots will\", do_sample=True, top_k=50, temperature=0.6, max_length=128, num_return_sequences=3)\n",
        "for sentence in sentences:\n",
        "    print(sentence[\"generated_text\"])\n",
        "    print(\"=\"*50)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6N5qFdcFZG1v"
      },
      "source": [
        "# generate Python Code!\n",
        "print(gpt_j_generator(\n",
        "\"\"\"\n",
        "import os\n",
        "# make a list of all african countries\n",
        "\"\"\",\n",
        "    do_sample=True, top_k=10, temperature=0.05, max_length=256)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-TOTvHiwwbK-"
      },
      "source": [
        "print(gpt_j_generator(\n",
        "\"\"\"\n",
        "import cv2\n",
        "\n",
        "image = \"image.png\"\n",
        "\n",
        "# load the image and flip it\n",
        "\"\"\",\n",
        "    do_sample=True, top_k=10, temperature=0.05, max_length=256)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_52OftmglAAv"
      },
      "source": [
        "# complete bash script!\n",
        "print(gpt_j_generator(\n",
        "\"\"\"\n",
        "# get .py files in /opt directory\n",
        "ls *.py /opt\n",
        "# get public ip address\n",
        "\"\"\", max_length=256, top_k=50, temperature=0.05, do_sample=True)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2x527AykVquF"
      },
      "source": [
        "# generating bash script!\n",
        "print(gpt_j_generator(\n",
        "\"\"\"\n",
        "# update the repository\n",
        "sudo apt-get update\n",
        "# install and start nginx\n",
        "\"\"\", max_length=128, top_k=50, temperature=0.1, do_sample=True)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "elK4JyyxwCPM"
      },
      "source": [
        "# Java code!\n",
        "print(gpt_j_generator(\n",
        "\"\"\"\n",
        "public class Test {\n",
        "\n",
        "public static void main(String[] args){\n",
        "    // printing the first 20 fibonacci numbers\n",
        "\"\"\", max_length=128, top_k=50, temperature=0.1, do_sample=True)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0US1Tv5xh-F2"
      },
      "source": [
        "# LATEX!\n",
        "print(gpt_j_generator(\n",
        "r\"\"\"\n",
        "% list of Asian countries\n",
        "\\begin{enumerate}\n",
        "\"\"\", max_length=128, top_k=15, temperature=0.1, do_sample=True)[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "clkMMnsgh_YF"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
transformers
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
"""TextGeneration-Transformers-PythonCodeTutorial.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1OUgJ92vQeFFYatf5gwtGulhA-mFwS0Md
"""

# !pip install transformers

from transformers import pipeline

# download & load GPT-2 model
gpt2_generator = pipeline('text-generation', model='gpt2')

# generate 3 different sentences
# results are sampled from the top 50 candidates
sentences = gpt2_generator("To be honest, neural networks", do_sample=True, top_k=50, temperature=0.6, max_length=128, num_return_sequences=3)
for sentence in sentences:
    print(sentence["generated_text"])
    print("="*50)

# download & load GPT-J model! It's 22.5GB in size
gpt_j_generator = pipeline('text-generation', model='EleutherAI/gpt-j-6B')

# generate sentences with TOP-K sampling
sentences = gpt_j_generator("To be honest, robots will", do_sample=True, top_k=50, temperature=0.6, max_length=128, num_return_sequences=3)
for sentence in sentences:
    print(sentence["generated_text"])
    print("="*50)

# generate Python Code!
print(gpt_j_generator(
"""
import os
# make a list of all african countries
""",
    do_sample=True, top_k=10, temperature=0.05, max_length=256)[0]["generated_text"])

print(gpt_j_generator(
"""
import cv2

image = "image.png"

# load the image and flip it
""",
    do_sample=True, top_k=10, temperature=0.05, max_length=256)[0]["generated_text"])

# complete bash script!
print(gpt_j_generator(
"""
# get .py files in /opt directory
ls *.py /opt
# get public ip address
""", max_length=256, top_k=50, temperature=0.05, do_sample=True)[0]["generated_text"])

# generating bash script!
print(gpt_j_generator(
"""
# update the repository
sudo apt-get update
# install and start nginx
""", max_length=128, top_k=50, temperature=0.1, do_sample=True)[0]["generated_text"])

# Java code!
print(gpt_j_generator(
"""
public class Test {

public static void main(String[] args){
    // printing the first 20 fibonacci numbers
""", max_length=128, top_k=50, temperature=0.1, do_sample=True)[0]["generated_text"])

# Commented out IPython magic to ensure Python compatibility.
# LATEX!
print(gpt_j_generator(
r"""
# % list of Asian countries
\begin{enumerate}
""", max_length=128, top_k=15, temperature=0.1, do_sample=True)[0]["generated_text"])
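Every generation call in the committed files uses top-k sampling. For reference, here is a minimal sketch of how those decoding parameters behave, reusing the same gpt2_generator pipeline defined above; the prompt, max_length of 64, and num_return_sequences of 2 are illustrative choices, not values from the commit.

from transformers import pipeline

# same GPT-2 text-generation pipeline as in the tutorial script
gpt2_generator = pipeline('text-generation', model='gpt2')

prompt = "To be honest, neural networks"

# greedy decoding: do_sample=False always picks the most likely next token,
# so repeated runs return the same (often repetitive) continuation
greedy = gpt2_generator(prompt, do_sample=False, max_length=64)
print(greedy[0]["generated_text"])

# top-k sampling: each step samples from the 50 most likely tokens;
# temperature below 1.0 sharpens that distribution, higher values add randomness
sampled = gpt2_generator(prompt, do_sample=True, top_k=50, temperature=0.6,
                         max_length=64, num_return_sequences=2)
for s in sampled:
    print(s["generated_text"])
    print("=" * 50)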
