diff --git a/stable_diffusion_tf/stable_diffusion.py b/stable_diffusion_tf/stable_diffusion.py index af2f301..9bf2702 100644 --- a/stable_diffusion_tf/stable_diffusion.py +++ b/stable_diffusion_tf/stable_diffusion.py @@ -30,6 +30,7 @@ def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_ self.text_encoder.compile(jit_compile=True) self.diffusion_model.compile(jit_compile=True) self.decoder.compile(jit_compile=True) + self.encoder.compile(jit_compile=True) def generate( self, @@ -58,7 +59,7 @@ def generate( input_image = Image.open(input_image) input_image = input_image.resize((self.img_width, self.img_height)) input_image = np.array(input_image)[... , :3] - input_image = (input_image.astype("float") / 255.0)*2 - 1 + input_image = (input_image.astype("float") / 255.0)*2 - 1 # Encode unconditional tokens (and their positions into an @@ -189,7 +190,7 @@ def get_models(img_height, img_width, download_weights=True): inp_img = keras.layers.Input((img_height, img_width, 3)) encoder = Encoder() encoder = keras.models.Model(inp_img, encoder(inp_img)) - + if download_weights: text_encoder_weights_fpath = keras.utils.get_file( origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5",