Skip to content

Commit

Permalink
#14 updated jaxlib version for jupyter notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-wright committed Dec 9, 2021
1 parent a5d907b commit 177f829
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
14 changes: 7 additions & 7 deletions flaxmodels/gpt2/gpt2_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
},
"source": [
"!pip install --upgrade pip\n",
"!pip install --upgrade \"jax[cuda111]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -162,7 +162,7 @@
"print()\n",
"print(sequence)"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -266,7 +266,7 @@
"output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
],
"execution_count": 3,
"execution_count": null,
"outputs": []
},
{
Expand Down Expand Up @@ -300,7 +300,7 @@
"output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
],
"execution_count": 4,
"execution_count": null,
"outputs": []
},
{
Expand Down Expand Up @@ -339,7 +339,7 @@
"output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ...}"
],
"execution_count": 5,
"execution_count": null,
"outputs": []
},
{
Expand Down Expand Up @@ -374,7 +374,7 @@
"output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ...}"
],
"execution_count": 6,
"execution_count": null,
"outputs": []
}
]
Expand Down
10 changes: 5 additions & 5 deletions flaxmodels/resnet/resnet_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
},
"source": [
"!pip install --upgrade pip\n",
"!pip install --upgrade \"jax[cuda111]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -196,7 +196,7 @@
"\n",
"labels = json.load(open('labels.json'))"
],
"execution_count": 2,
"execution_count": null,
"outputs": []
},
{
Expand Down Expand Up @@ -249,7 +249,7 @@
"for i in range(top5_classes.shape[0]):\n",
" print(f'{i + 1}.', labels[top5_classes[i]])"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
Expand Down Expand Up @@ -331,7 +331,7 @@
"for key in out.keys():\n",
" print(key, out[key].shape)"
],
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down
14 changes: 7 additions & 7 deletions flaxmodels/stylegan2/stylegan2_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
},
"source": [
"!pip install --upgrade pip\n",
"!pip install --upgrade \"jax[cuda111]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -214,7 +214,7 @@
" ax[i].imshow(images[i])\n",
" ax[i].axis('off')"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -290,7 +290,7 @@
" ax[i].imshow(images[i])\n",
" ax[i].axis('off')"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -376,7 +376,7 @@
" ax[i].imshow(images[i])\n",
" ax[i].axis('off')"
],
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -501,7 +501,7 @@
"plt.imshow(np.clip(grid, 0, 1))\n",
"plt.axis('off')"
],
"execution_count": 5,
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
Expand Down Expand Up @@ -560,7 +560,7 @@
"params = discriminator.init(key, img)\n",
"out = discriminator.apply(params, img)"
],
"execution_count": 6,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down
12 changes: 6 additions & 6 deletions flaxmodels/vgg/vgg_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
},
"source": [
"!pip install --upgrade pip\n",
"!pip install --upgrade \"jax[cuda111]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -196,7 +196,7 @@
"\n",
"labels = json.load(open('labels.json'))"
],
"execution_count": 2,
"execution_count": null,
"outputs": []
},
{
Expand Down Expand Up @@ -248,7 +248,7 @@
"for i in range(top5_classes.shape[0]):\n",
" print(f'{i + 1}.', labels[top5_classes[i]])"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
Expand Down Expand Up @@ -334,7 +334,7 @@
"for key in out.keys():\n",
" print(key, out[key].shape)"
],
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
Expand Down Expand Up @@ -430,7 +430,7 @@
"for key in out.keys():\n",
" print(key, out[key].shape)"
],
"execution_count": 5,
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
Expand Down

0 comments on commit 177f829

Please sign in to comment.