Skip to content

Commit

Permalink
Fix sdxl q_unet config (#2081)
Browse files Browse the repository at this point in the history
Signed-off-by: Kaihui-intel <[email protected]>
  • Loading branch information
Kaihui-intel authored Dec 6, 2024
1 parent 306bd63 commit 697c5be
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,13 @@ function run_benchmark {

if [[ ${mode} == "performance" ]]; then
extra_cmd=$extra_cmd" --performance"
if [[ ${int8} == "true" ]]; then
extra_cmd=$extra_cmd" --int8"
fi
echo $extra_cmd

python -u sdxl_smooth_quant.py \
--model_name_or_path ${model_name_or_path} \
--latent ${latent} \
${extra_cmd}
else
if [[ ${int8} == "true" ]]; then
extra_cmd=$extra_cmd" --int8"
fi
echo $extra_cmd

python -u sdxl_smooth_quant.py \
Expand All @@ -82,7 +76,7 @@ function run_benchmark {
cd mlperf_sd_inference
cp ../main.py ./
if [ -d "../saved_results/" ]; then
mv ../saved_results/ ./
cp -r ../saved_results/ ./
fi

python -u main.py \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,10 @@ def forward_loop(model):
if args.int8:
from neural_compressor.torch.quantization import load
q_unet = load(os.path.abspath(os.path.expanduser(args.output_dir)))
setattr(q_unet, "config", pipeline.unet.config)
else:
q_unet = pipeline.unet

if not hasattr(q_unet, "config"):
setattr(q_unet, "config", pipeline.unet.config)
pipeline.unet = q_unet
quant_images = prompts2images(pipeline, prompts, n_steps=args.n_steps, latent=init_latent)
save_images(prompts, quant_images, args.output_dir, prefix='quant')
Expand Down

0 comments on commit 697c5be

Please sign in to comment.