Skip to content

Commit

Permalink
added image to image
Browse files Browse the repository at this point in the history
  • Loading branch information
divamgupta committed Sep 30, 2022
1 parent 2873079 commit c31ff5e
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 11 deletions.
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ pip install -r requirements.txt
If you installed the package, you can use it as follows:

```python
from stable_diffusion_tf.stable_diffusion import Text2Image
from stable_diffusion_tf.stable_diffusion import StableDiffusion
from PIL import Image
generator = Text2Image(
generator = StableDiffusion(
img_height=512,
img_width=512,
jit_compile=False,
Expand All @@ -90,6 +90,18 @@ img = generator.generate(
temperature=1,
batch_size=1,
)
# For image-to-image generation, also pass an input image:
img = generator.generate(
"A Halloween bedroom",
num_steps=50,
unconditional_guidance_scale=7.5,
temperature=1,
batch_size=1,
input_image="/path/to/img.png"
)
Image.fromarray(img[0]).save("output.png")
```

Expand Down
34 changes: 34 additions & 0 deletions stable_diffusion_tf/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,37 @@ def __init__(self):
PaddedConv2D(3, 3, padding=1),
]
)


class Encoder(keras.Sequential):
    """Sequential image encoder that produces a scaled 4-channel latent.

    The stack is: a 128-filter input convolution, three stages of two
    ``ResnetBlock``s followed by a stride-2 ``PaddedConv2D``, a middle
    section of resnet / attention / resnet blocks, group normalisation
    with a swish activation, and two 8-filter convolutions.  The final
    ``Lambda`` keeps only the first 4 of those 8 channels and multiplies
    them by 0.18215.
    """

    def __init__(self):
        # Layers are created strictly in network order so that Keras'
        # automatic layer naming matches the original definition.
        stack = [PaddedConv2D(128, 3, padding=1)]

        # (input channels, output channels) for each down-sampling stage.
        for in_ch, out_ch in ((128, 128), (128, 256), (256, 512)):
            stack.append(ResnetBlock(in_ch, out_ch))
            stack.append(ResnetBlock(out_ch, out_ch))
            stack.append(PaddedConv2D(out_ch, 3, padding=1, stride=2))

        # Two more resnet blocks at 512 channels (no further striding).
        stack.append(ResnetBlock(512, 512))
        stack.append(ResnetBlock(512, 512))

        # Middle section: resnet -> attention -> resnet.
        stack.append(ResnetBlock(512, 512))
        stack.append(AttentionBlock(512))
        stack.append(ResnetBlock(512, 512))

        # Output head.
        stack.append(tfa.layers.GroupNormalization(epsilon=1e-5))
        stack.append(keras.layers.Activation("swish"))
        stack.append(PaddedConv2D(8, 3, padding=1))
        stack.append(PaddedConv2D(8, 1))
        # Keep the first 4 channels and apply the latent scale factor.
        # NOTE(review): presumably the 8 channels are (mean, logvar) and only
        # the mean is kept -- confirm against the reference autoencoder.
        stack.append(keras.layers.Lambda(lambda x: x[..., :4] * 0.18215))

        super().__init__(stack)

52 changes: 45 additions & 7 deletions stable_diffusion_tf/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
import tensorflow as tf
from tensorflow import keras

from .autoencoder_kl import Decoder
from .autoencoder_kl import Decoder, Encoder
from .diffusion_model import UNetModel
from .clip_encoder import CLIPTextTransformer
from .clip_tokenizer import SimpleTokenizer
from .constants import _UNCONDITIONAL_TOKENS, _ALPHAS_CUMPROD
from PIL import Image

MAX_TEXT_LEN = 77


class Text2Image:
class StableDiffusion:
def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_weights=True):
self.img_height = img_height
self.img_width = img_width
self.tokenizer = SimpleTokenizer()

text_encoder, diffusion_model, decoder = get_models(img_height, img_width, download_weights=download_weights)
text_encoder, diffusion_model, decoder, encoder = get_models(img_height, img_width, download_weights=download_weights)
self.text_encoder = text_encoder
self.diffusion_model = diffusion_model
self.decoder = decoder
self.encoder = encoder
if jit_compile:
self.text_encoder.compile(jit_compile=True)
self.diffusion_model.compile(jit_compile=True)
Expand All @@ -37,6 +39,8 @@ def generate(
unconditional_guidance_scale=7.5,
temperature=1,
seed=None,
input_image=None,
input_image_strength=0.5,
):
# Tokenize prompt (i.e. starting context)
inputs = self.tokenizer.encode(prompt)
Expand All @@ -50,6 +54,13 @@ def generate(
pos_ids = np.repeat(pos_ids, batch_size, axis=0)
context = self.text_encoder.predict_on_batch([phrase, pos_ids])

if type(input_image) is str:
input_image = Image.open(input_image)
input_image = input_image.resize((self.img_width, self.img_height))
input_image = np.array(input_image)[... , :3]
input_image = (input_image.astype("float") / 255.0)*2 - 1


# Encode unconditional tokens (and their positions into an
# "unconditional context vector"
unconditional_tokens = np.array(_UNCONDITIONAL_TOKENS)[None].astype("int32")
Expand All @@ -59,10 +70,14 @@ def generate(
[self.unconditional_tokens, pos_ids]
)
timesteps = np.arange(1, 1000, 1000 // num_steps)
input_img_noise_t = timesteps[ int(len(timesteps)*input_image_strength) ]
latent, alphas, alphas_prev = self.get_starting_parameters(
timesteps, batch_size, seed
timesteps, batch_size, seed , input_image=input_image, input_img_noise_t=input_img_noise_t
)

if input_image is not None:
timesteps = timesteps[: int(len(timesteps)*input_image_strength)]

# Diffusion stage
progbar = tqdm(list(enumerate(timesteps))[::-1])
for index, timestep in progbar:
Expand All @@ -85,6 +100,14 @@ def generate(
decoded = ((decoded + 1) / 2) * 255
return np.clip(decoded, 0, 255).astype("uint8")

def add_noise(self, x, t, seed=None):
    """Blend ``x`` with Gaussian noise according to timestep ``t``.

    Computes ``sqrt(a) * x + sqrt(1 - a) * noise`` where
    ``a = _ALPHAS_CUMPROD[t]``.

    Args:
        x: Tensor to be noised.  The noise is drawn with the same shape and
            dtype as ``x`` (the previous version hard-coded 4 channels and
            the default float32 dtype, which cannot be added to a
            non-float32 latent).
        t: Index into ``_ALPHAS_CUMPROD``.
        seed: Optional seed forwarded to ``tf.random.normal``.  Defaults to
            ``None``, which keeps the previous unseeded behaviour.

    Returns:
        The noised tensor, same shape as ``x``.
    """
    noise = tf.random.normal(tf.shape(x), dtype=x.dtype, seed=seed)
    alpha_prod = _ALPHAS_CUMPROD[t]
    sqrt_alpha_prod = alpha_prod ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5

    return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise

def timestep_embedding(self, timesteps, dim=320, max_period=10000):
half = dim // 2
freqs = np.exp(
Expand Down Expand Up @@ -125,12 +148,17 @@ def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
return x_prev, pred_x0

def get_starting_parameters(self, timesteps, batch_size, seed, input_image=None, input_img_noise_t=None):
    """Build the starting latent and the alpha schedules for sampling.

    Args:
        timesteps: Sequence of indices into ``_ALPHAS_CUMPROD``.
        batch_size: Number of latents to produce.
        seed: Seed for the random latent used when no input image is given.
        input_image: Optional single image array (without a batch axis; one
            is added here before encoding).  When ``None`` the latent is
            pure Gaussian noise.
        input_img_noise_t: Index into ``_ALPHAS_CUMPROD`` selecting how much
            noise is added to the encoded image.  Required whenever
            ``input_image`` is given.

    Returns:
        ``(latent, alphas, alphas_prev)`` where ``alphas`` holds
        ``_ALPHAS_CUMPROD[t]`` for every timestep and ``alphas_prev`` is the
        same list shifted right by one with ``1.0`` in front.

    Raises:
        ValueError: If ``input_image`` is given without ``input_img_noise_t``.
    """
    n_h = self.img_height // 8
    n_w = self.img_width // 8
    alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
    alphas_prev = [1.0] + alphas[:-1]

    if input_image is None:
        latent = tf.random.normal((batch_size, n_h, n_w, 4), seed=seed)
        return latent, alphas, alphas_prev

    if input_img_noise_t is None:
        raise ValueError(
            "input_img_noise_t must be provided when input_image is given"
        )

    latent = self.encoder(input_image[None])
    # Repeat across the batch BEFORE adding noise: add_noise draws noise per
    # batch element, so noising first and repeating afterwards would give
    # every batch element the exact same starting latent.
    latent = tf.repeat(latent, batch_size, axis=0)
    latent = self.add_noise(latent, input_img_noise_t)
    return latent, alphas, alphas_prev


Expand All @@ -157,6 +185,10 @@ def get_models(img_height, img_width, download_weights=True):
latent = keras.layers.Input((n_h, n_w, 4))
decoder = Decoder()
decoder = keras.models.Model(latent, decoder(latent))

inp_img = keras.layers.Input((img_height, img_width, 3))
encoder = Encoder()
encoder = keras.models.Model(inp_img, encoder(inp_img))

if download_weights:
text_encoder_weights_fpath = keras.utils.get_file(
Expand All @@ -172,7 +204,13 @@ def get_models(img_height, img_width, download_weights=True):
file_hash="6d3c5ba91d5cc2b134da881aaa157b2d2adc648e5625560e3ed199561d0e39d5",
)

encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/encoder_newW.h5",
file_hash="56a2578423c640746c5e90c0a789b9b11481f47497f817e65b44a1a5538af754",
)

text_encoder.load_weights(text_encoder_weights_fpath)
diffusion_model.load_weights(diffusion_model_weights_fpath)
decoder.load_weights(decoder_weights_fpath)
return text_encoder, diffusion_model, decoder
encoder.load_weights(encoder_weights_fpath)
return text_encoder, diffusion_model, decoder , encoder
4 changes: 2 additions & 2 deletions text2image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tensorflow import keras
from stable_diffusion_tf.stable_diffusion import Text2Image
from stable_diffusion_tf.stable_diffusion import StableDiffusion
import argparse
from PIL import Image

Expand Down Expand Up @@ -65,7 +65,7 @@
print("Using mixed precision.")
keras.mixed_precision.set_global_policy("mixed_float16")

generator = Text2Image(img_height=args.H, img_width=args.W, jit_compile=False)
generator = StableDiffusion(img_height=args.H, img_width=args.W, jit_compile=False)
img = generator.generate(
args.prompt,
num_steps=args.steps,
Expand Down

0 comments on commit c31ff5e

Please sign in to comment.