Skip to content

Commit

Permalink
Merge pull request #40 from matpalm/compile_encoder
Browse files Browse the repository at this point in the history
jit compile encoder as well
  • Loading branch information
divamgupta authored Oct 5, 2022
2 parents bd15f2a + 68908f3 commit c172a40
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions stable_diffusion_tf/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_
self.text_encoder.compile(jit_compile=True)
self.diffusion_model.compile(jit_compile=True)
self.decoder.compile(jit_compile=True)
self.encoder.compile(jit_compile=True)

def generate(
self,
Expand Down Expand Up @@ -58,7 +59,7 @@ def generate(
input_image = Image.open(input_image)
input_image = input_image.resize((self.img_width, self.img_height))
input_image = np.array(input_image)[... , :3]
input_image = (input_image.astype("float") / 255.0)*2 - 1
input_image = (input_image.astype("float") / 255.0)*2 - 1


# Encode unconditional tokens (and their positions into an
Expand Down Expand Up @@ -189,7 +190,7 @@ def get_models(img_height, img_width, download_weights=True):
inp_img = keras.layers.Input((img_height, img_width, 3))
encoder = Encoder()
encoder = keras.models.Model(inp_img, encoder(inp_img))

if download_weights:
text_encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5",
Expand Down

0 comments on commit c172a40

Please sign in to comment.