Skip to content

Commit

Permalink
pydantic fixes and cli sample fix (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheFusion21 authored Jan 8, 2024
1 parent 4451ab7 commit ed5885b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
2 changes: 1 addition & 1 deletion imagen_pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def sample(

# generate image

pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True)
pil_image = imagen.sample([text], cond_scale = cond_scale, return_pil_images = True)

image_path = f'./{simple_slugify(text)}.png'
pil_image[0].save(image_path)
Expand Down
26 changes: 12 additions & 14 deletions imagen_pytorch/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class UnetConfig(AllowExtraBaseModel):
dim: int
dim_mults: ListOrTuple(int)
text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
cond_dim: int = None
cond_dim: Optional[int] = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
Expand All @@ -56,7 +56,7 @@ class Unet3DConfig(AllowExtraBaseModel):
dim: int
dim_mults: ListOrTuple(int)
text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
cond_dim: int = None
cond_dim: Optional[int] = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
Expand All @@ -75,12 +75,11 @@ class ImagenConfig(AllowExtraBaseModel):
loss_type: str = 'l2'
cond_drop_prob: float = 0.5

@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
unets = values.get('unets')
if len(image_sizes) != len(unets):
raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
return image_sizes
@model_validator(mode="after")
def check_image_sizes(self):
if len(self.image_sizes) != len(self.unets):
raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
return self

def create(self):
decoder_kwargs = self.dict()
Expand Down Expand Up @@ -123,12 +122,11 @@ class ElucidatedImagenConfig(AllowExtraBaseModel):
S_tmax: SingleOrList(int) = 50
S_noise: SingleOrList(float) = 1.003

@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
unets = values.get('unets')
if len(image_sizes) != len(unets):
raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
return image_sizes
@model_validator(mode="after")
def check_image_sizes(self):
if len(self.image_sizes) != len(self.unets):
raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
return self

def create(self):
decoder_kwargs = self.dict()
Expand Down

0 comments on commit ed5885b

Please sign in to comment.