Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/feat/refactor3' into feat/refactor3
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Nov 3, 2024
2 parents 4668a35 + dd0234d commit cb0285b
Show file tree
Hide file tree
Showing 39 changed files with 998 additions and 1,618 deletions.
12 changes: 10 additions & 2 deletions swift/cli/merge_lora.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import merge_lora_main
from swift.llm import ExportArguments, SwiftPipeline, merge_lora


class SwiftMergeLoRA(SwiftPipeline):
    """Pipeline entry point that merges LoRA adapter weights into the base model.

    Thin wrapper: argument parsing is handled by the `SwiftPipeline`
    machinery using `ExportArguments`; the merge itself is delegated to
    the shared `merge_lora` helper from `swift.llm.export`.
    """

    # Dataclass used by SwiftPipeline to parse CLI arguments into `self.args`.
    args_class = ExportArguments

    def run(self):
        # Delegate the actual weight merge to the export helper.
        merge_lora(self.args)


if __name__ == '__main__':
merge_lora_main(replace_if_exists=True)
SwiftMergeLoRA().main()
14 changes: 6 additions & 8 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
if TYPE_CHECKING:
# Recommend using `xxx_main`
from .infer import (VllmEngine, RequestConfig, InferStats, LmdeployEngine, PtEngine, infer_main, deploy_main,
PtLoRARequest, InferClient)
from .export import export_main, merge_lora
PtLoRARequest, InferClient, SwiftInfer, SwiftDeploy)
from .export import export_main, merge_lora, quantize_model, export_to_ollama
from .eval import eval_main
from .train import sft_main, pt_main, rlhf_main
from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments, RLHFArguments,
WebUIArguments, AppUIArguments)
WebUIArguments, BaseArguments)
from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest)
from .model import (MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory,
Expand All @@ -22,22 +22,21 @@
standard_keys, load_dataset, DATASET_TYPE, HfDataset, sample_dataset)
from .utils import (deep_getattr, to_device, History, decode_base64, history_to_messages, messages_to_history,
safe_tokenizer_decode)
from .module_mapping import MODEL_KEYS_MAPPING, MultiModelKeys
from .base import SwiftPipeline
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
'rlhf': ['rlhf_main'],
'infer': [
'deploy_main', 'VllmEngine', 'RequestConfig', 'InferStats', 'LmdeployEngine', 'PtEngine', 'infer_main',
'PtLoRARequest', 'InferClient'
'PtLoRARequest', 'InferClient', 'SwiftInfer', 'SwiftDeploy'
],
'export': ['export_main', 'merge_lora'],
'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
'eval': ['eval_main'],
'train': ['sft_main', 'pt_main', 'rlhf_main'],
'argument': [
'EvalArguments', 'InferArguments', 'SftArguments', 'ExportArguments', 'WebUIArguments', 'DeployArguments',
'RLHFArguments', 'AppUIArguments'
'RLHFArguments', 'BaseArguments'
],
'template': [
'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
Expand All @@ -58,7 +57,6 @@
'deep_getattr', 'to_device', 'History', 'decode_base64', 'history_to_messages', 'messages_to_history',
'safe_tokenizer_decode'
],
'module_mapping': ['MODEL_KEYS_MAPPING', 'MultiModelKeys'],
'base': ['SwiftPipeline']
}

Expand Down
9 changes: 9 additions & 0 deletions swift/llm/argument/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from swift.hub import default_hub
from swift.utils import check_json_format, get_logger, is_master
from ..tuner_args import TunerArguments, get_supported_tuners
from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
Expand Down Expand Up @@ -56,6 +57,14 @@ def __post_init__(self):
if default_hub.try_login(self.hub_token):
logger.info('hub login successful!')

@property
def supported_tuners(self):
    # Convenience accessor: delegates to tuner_args.get_supported_tuners()
    # so argument classes can validate `train_type` against the set of
    # tuner names known to the framework.
    return get_supported_tuners()

@property
def adapters_can_be_merged(self):
    """Return the collection of adapter (tuner) types whose weights can be merged.

    NOTE(review): `TunerArguments.adapters_can_be_merged` is itself declared
    as a `@property`, so accessing it on the *class* yields the property
    descriptor object rather than its value — the original body therefore
    returned a `property` instance, which breaks callers such as
    `['full'] + self.adapters_can_be_merged`. Invoke the underlying getter
    explicitly to obtain the actual value.
    """
    return TunerArguments.adapters_can_be_merged.fget(self)

def _load_args(self) -> None:
"""Load specific attributes from sft_args.json"""
from swift.llm import SftArguments, ExportArguments, InferArguments
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class QuantizeArguments:
quant_method: Literal['bnb', 'hqq', 'eetq', 'awq', 'gptq'] = None
# bnb: 4,8; hqq: 1,2,3,4,8'; eetq: 8
# awq: 4; gptq: 2,3,4,8
quant_bits: Literal[0, 1, 2, 3, 4, 8] = 0
quant_bits: Literal[0, 1, 2, 3, 4, 8] = 4
# hqq
hqq_axis: Optional[int] = None
# bnb
Expand Down
24 changes: 14 additions & 10 deletions swift/llm/argument/export_args.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Literal, Optional

import torch.distributed as dist

from swift.utils import get_logger, is_dist
from .base_args import BaseArguments, to_abspath
from .merge_args import MergeArguments
from .tuner_args import adapters_can_be_merged

logger = get_logger()


@dataclass
class ExportArguments(BaseArguments, MergeArguments):
class ExportArguments(MergeArguments, BaseArguments):
"""
ExportArguments is a dataclass that inherits from BaseArguments and MergeArguments.
Expand All @@ -37,14 +36,18 @@ class ExportArguments(BaseArguments, MergeArguments):
tp (int): Tensor parallelism degree.
pp (int): Pipeline parallelism degree.
"""
ckpt_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'})
output_dir: Optional[str] = None
device_map: str = 'auto' # e.g. 'cpu', 'auto'
safe_serialization: bool = True
max_shard_size: str = '5GB'

to_peft_format: bool = False
# awq/gptq
quant_n_samples: int = 256
quant_seqlen: int = 2048
quant_device_map: str = 'auto' # e.g. 'cpu', 'auto'
quant_batch_size: int = 1
group_size: int = 128

# ollama
to_ollama: bool = False
Expand Down Expand Up @@ -72,9 +75,10 @@ def _init_quant(self):
raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')

def _init_output_dir(self):
if self.ckpt_dir is None:
ckpt_dir = self.ckpt_dir
if ckpt_dir is None:
ckpt_dir = self.model_info.model_dir
ckpt_dir, ckpt_name = os.path.split(model_dir)
ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
if self.to_peft_format:
suffix = 'peft'
elif self.merge_lora:
Expand All @@ -87,22 +91,22 @@ def _init_output_dir(self):
suffix = f'tp{self.tp}-pp{self.pp}'
elif self.to_hf:
suffix = 'hf'
else:
raise ValueError(f'args: {self}')
self.output_dir = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')

logger.info(f'Setting args.output_dir: {self.output_dir}')

self.output_dir = to_abspath(self.output_dir)
assert not os.path.exists(self.output_dir), (f'args.output_dir: {self.output_dir} already exists.')
assert not os.path.exists(self.output_dir), f'args.output_dir: {self.output_dir} already exists.'

def __post_init__(self):
super().__post_init__()
self._init_output_dir()
if self.quant_bits > 0:
self._init_quant()
elif self.to_ollama:
assert self.train_type in ['full'] + adapters_can_be_merged()
if self.train_type != 'full':
self.merge_lora = True
assert self.train_type in ['full'] + self.adapters_can_be_merged

elif self.to_megatron or self.to_hf:
os.environ['RANK'] = '0'
Expand Down
5 changes: 3 additions & 2 deletions swift/llm/argument/infer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from swift.utils import get_logger, is_lmdeploy_available, is_vllm_available
from .base_args import BaseArguments, to_abspath
from .merge_args import MergeArguments
from .tuner_args import adapters_can_be_merged

logger = get_logger()

Expand Down Expand Up @@ -87,8 +86,8 @@ class InferArguments(BaseArguments, MergeArguments, VllmArguments, LmdeployArgum
save_result (bool): Flag to indicate if results should be saved. Default is True.
stream (Optional[bool]): Flag to indicate if streaming should be enabled. Default is None.
"""
infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'
ckpt_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'})
infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'
max_batch_size: int = 16 # for pt engine

# only for inference
Expand Down Expand Up @@ -136,6 +135,8 @@ def __post_init__(self) -> None:
self._init_result_dir()
self._init_stream()
self._init_eval_human()
if self.ckpt_dir is None:
self.train_type = 'full'

def _init_eval_human(self):
if len(self.dataset) == 0 and len(self.val_dataset) == 0:
Expand Down
2 changes: 0 additions & 2 deletions swift/llm/argument/merge_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ class MergeArguments:
Args:
merge_lora (bool): Flag to indicate if LoRA merging is enabled. Default is False.
merge_device_map (str): Device map configuration for merging. Default is 'auto'.
use_merge_kit (bool): Flag to indicate merge with `mergekit`. Default is False.
instruct_model_id_or_path (Optional[str]): Path or ID of the instruct model. Use when `use_merge_kit` is True.
instruct_model_revision (Optional[str]): Revision of the instruct model. Use when `use_merge_kit` is True.
"""
merge_lora: bool = False
merge_device_map: str = 'auto'

use_merge_kit: bool = False
instruct_model_id_or_path: Optional[str] = None
Expand Down
8 changes: 4 additions & 4 deletions swift/llm/argument/tuner_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ def get_supported_tuners():
return {'lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft'}


def adapters_can_be_merged():
return ['lora', 'longlora', 'llamapro', 'adalora']


@dataclass
class TunerArguments:
"""
Expand Down Expand Up @@ -184,3 +180,7 @@ class TunerArguments:
@property
def is_adapter(self) -> bool:
    """Return True unless full-parameter training ('full') is selected."""
    non_adapter_types = {'full'}
    return self.train_type not in non_adapter_types

@property
def adapters_can_be_merged(self):
    """Return the tuner types whose adapter weights can be merged into the base model.

    Returns a *list* — matching the module-level helper this property
    replaced (`adapters_can_be_merged()` returned
    `['lora', 'longlora', 'llamapro', 'adalora']`) — because callers
    concatenate it with other lists, e.g.
    `['full'] + self.adapters_can_be_merged`; returning a set would raise
    TypeError at those call sites.
    """
    return ['lora', 'longlora', 'llamapro', 'adalora']
13 changes: 8 additions & 5 deletions swift/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import os
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, List, Optional, Type, TypeVar, Union
from typing import Callable, Generic, List, Optional, Type, TypeVar, Union

from swift.llm import BaseArguments
from swift.utils import get_logger, parse_args, seed_everything
from .argument import BaseArguments

logger = get_logger()

T_Args = TypeVar('T_Args', bound=BaseArguments)


class SwiftPipeline(ABC):
args_class = None
class SwiftPipeline(ABC, Generic[T_Args]):
args_class = BaseArguments

def parse_args(self, args: Union[List[str], T_Args, None] = None) -> T_Args:
def __init__(self, args: Union[List[str], T_Args, None] = None):
    # Parse arguments eagerly (from a CLI-style string list, an existing
    # args dataclass, or sys.argv when None) so subclasses can rely on
    # `self.args` being fully constructed before `run()` is called.
    self.args = self._parse_args(args)

def _parse_args(self, args: Union[List[str], T_Args, None] = None) -> T_Args:
if isinstance(args, BaseArguments):
return args
assert self.args_class is not None
Expand Down
5 changes: 3 additions & 2 deletions swift/llm/dataset/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
LoadFunction = Callable[..., DATASET_TYPE]
logger = get_logger()

DATASET_MAPPING: Dict[str, Dict[str, Any]] = {}


@dataclass
class SubsetDataset:
Expand Down Expand Up @@ -75,6 +73,9 @@ def __post_init__(self):
self.subsets[i] = SubsetDataset(name=subset)


DATASET_MAPPING: Dict[str, DatasetMeta] = {}


def register_dataset(dataset_meta: DatasetMeta, *, exist_ok: bool = False) -> None:
"""Register dataset to the dataset mapping
Expand Down
5 changes: 5 additions & 0 deletions swift/llm/export/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .merge_lora import merge_lora
from .export import export_main, SwiftExport
from .quant import quantize_model
from .ollama import export_to_ollama

Loading

0 comments on commit cb0285b

Please sign in to comment.