Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Nov 29, 2023
1 parent e8d72c2 commit dd4b1c7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 5 additions & 3 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def validate_model_outputs(
with torch.no_grad():
reference_model.eval()
ref_inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
if reference_model.config.is_encoder_decoder:
if getattr(reference_model.config, "is_encoder_decoder", False):
reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
if "AutoencoderKL" in getattr(config._config, "_class_name", "") or reference_model.config.is_encoder_decoder:
if "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
reference_model.config, "is_encoder_decoder", False
):
# VAE components for stable diffusion or Encoder-Decoder models
ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
Expand Down Expand Up @@ -428,7 +430,7 @@ def export_neuronx(
dummy_inputs_tuple = tuple(dummy_inputs.values())

aliases = {}
if model.config.is_encoder_decoder:
if getattr(model.config, "is_encoder_decoder", False):
checked_model = config.patch_model_for_export(model, **input_shapes)
if getattr(config, "is_decoder", False):
aliases = config.generate_io_aliases(checked_model)
Expand Down
4 changes: 3 additions & 1 deletion optimum/neuron/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronConfig":
# Neuron config constructuor
task = getattr(config, "task") or TasksManager.infer_task_from_model(cls.auto_model_class)
task = TasksManager.map_from_synonym(task)
model_type = neuron_configs.get("model_type", None) or config.model_type
model_type = neuron_configs.get("model_type", None)
if not (model_type and model_type != "None"):
model_type = config.model_type
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type, exporter="neuron", task=task
)
Expand Down

0 comments on commit dd4b1c7

Please sign in to comment.