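Distributed Training of a Multimodal Model with Swift on 910B Chips: A Practical Walkthrough
I. Environment Preparation
1. Prerequisites
First, prepare the training machines and data. The author used 32 nodes of 910B NPUs (8 NPUs per node, 256 in total) and 3 million web-page training samples.
2. Environment Setup
Install the multimodal training framework ms-swift, then torch-npu and DeepSpeed.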
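# Install ms-swift (installing from source is currently recommended; once a new version is released it can be installed directly with pip)
# This walkthrough used ms-swift 2.3.2 (see the appendix); the 3.x releases renamed several sft arguments used below, so pin 2.3.2 when reproducing, e.g. pip install 'ms-swift[llm]==2.3.2'
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
# Install torch-npu; its version must match the installed torch (this setup used torch 2.2.0 with torch-npu 2.2.0.post1)
pip install torch-npu decorator
# Install DeepSpeed
pip install deepspeed
# See the appendix at the end for the full list of Python dependency versions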
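Because each torch-npu release is built against a specific torch release (the appendix pairs torch 2.2.0 with torch-npu 2.2.0.post1), a quick check that the two base versions agree can catch a mismatched install before a long job starts. This is a minimal sketch: the helper names are illustrative, and "equal base versions" is a rule of thumb rather than Ascend's official compatibility matrix.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def base_release(ver: str) -> str:
    """Leading numeric release: '2.2.0.post1' -> '2.2.0', '2.2.0+cpu' -> '2.2.0'."""
    m = re.match(r"\d+(?:\.\d+)*", ver)
    return m.group(0) if m else ""


def torch_npu_matches(torch_ver: str, torch_npu_ver: str) -> bool:
    """Rule of thumb (assumption): torch-npu X.Y.Z[.postN] targets torch X.Y.Z."""
    base = base_release(torch_ver)
    return base != "" and base == base_release(torch_npu_ver)


def installed_version(pkg: str) -> Optional[str]:
    """Installed version of a distribution, or None if it cannot be found."""
    # Try both spellings of the distribution name to be safe
    for name in (pkg, pkg.replace("-", "_")):
        try:
            return version(name)
        except PackageNotFoundError:
            continue
    return None


if __name__ == "__main__":
    t, tn = installed_version("torch"), installed_version("torch-npu")
    if t is None or tn is None:
        print(f"torch={t}, torch-npu={tn}: at least one package is missing")
    else:
        status = "OK" if torch_npu_matches(t, tn) else "MISMATCH"
        print(f"torch={t}, torch-npu={tn}: {status}")
```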
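3. Environment Verification
1) Verify the torch environment. Check in the AI computing platform's IDE that the environment is set up correctly; to save compute, this check was run in a single-card IDE.
# test_env.py
from transformers.utils import is_torch_npu_available
import torch
import torch_npu  # registers the "npu" device and the torch.npu namespace with PyTorch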
print(is_torch_npu_available()) # True
print(torch.npu.device_count()) # 1
print(torch.randn(10, device='npu:0'))
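In a healthy environment, the script prints True, then 1, then a tensor of 10 random values with device='npu:0'.
2) Check the NPU status with the npu-smi info command:
npu-smi info
In a healthy environment, each NPU is listed with its Health status shown as OK, along with its power draw, temperature, and memory usage.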
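II. Training Task
The task in this walkthrough trains the multimodal LLM InternVL2 to generate HTML code from images: given a screenshot of a web page, the model produces the corresponding HTML code. The training set contains 3 million samples drawn from the WebSight dataset on Hugging Face and an internal HTML training set. Because InternVL2's vision encoder is relatively small (so it needs little NPU memory) and uses no special operators, the model is compatible with swift and torch-npu.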
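The --dataset flag in the scripts below points at a JSON Lines file. The following sketch shows how (screenshot, HTML) pairs, such as WebSight records, could be turned into that file. It assumes the query/response/images record layout that ms-swift 2.x documents for custom multimodal datasets. Verify that layout against the ms-swift 2.3 custom-dataset docs, including whether InternVL2 needs an explicit <image> placeholder in the query. The prompt text and the helper names are hypothetical.

```python
import json
from pathlib import Path
from typing import Iterable, Tuple

# Hypothetical instruction; the prompt used in the original run is not given.
PROMPT = "Generate the HTML code for this web page screenshot."


def to_swift_record(image_path: str, html: str, prompt: str = PROMPT) -> dict:
    """One sample in the query/response/images layout assumed for ms-swift 2.x custom datasets."""
    return {"query": prompt, "response": html, "images": [image_path]}


def write_jsonl(pairs: Iterable[Tuple[str, str]], out_path: str) -> int:
    """Write (screenshot path, HTML source) pairs as JSON Lines; returns the number of records."""
    count = 0
    with Path(out_path).open("w", encoding="utf-8") as f:
        for image_path, html in pairs:
            # ensure_ascii=False keeps non-ASCII page text readable;
            # json.dumps escapes newlines inside the HTML, so each record stays on one line
            f.write(json.dumps(to_swift_record(image_path, html), ensure_ascii=False) + "\n")
            count += 1
    return count


if __name__ == "__main__":
    print(json.dumps(to_swift_record("/path/to/screenshot.png", "<html><body><h1>Hi</h1></body></html>")))
```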
1. Single-node, 8-card training script
source /usr/local/Ascend/ascend-toolkit/set_env.sh
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NPROC_PER_NODE=8 \
swift sft \
--model_type internvl2-8b \
--model_id_or_path /datas/datasets/InternVL2-8B \
--sft_type lora \
--use_rslora \
--lora_rank 8 \
--lora_alpha 16 \
--num_train_epochs 1 \
--dtype bf16 \
--max_length 4096 \
--eval_steps 1000 \
--save_steps 50 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 32 \
--logging_steps 5 \
--dataset /datasets/train_swift/swift_html_0924_2.jsonl \
--output_dir /ms-swift-2.3.2/add_internVL2/ckpts \
--deepspeed default-zero2
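--use_rslora changes how the LoRA update is scaled. The following is a minimal sketch, assuming swift forwards the flag to PEFT's LoraConfig(use_rslora=True). With that flag, PEFT scales the adapter update by lora_alpha / sqrt(r) rather than lora_alpha / r. The function name is illustrative.

```python
import math


def lora_scaling(lora_alpha: float, rank: int, use_rslora: bool) -> float:
    """Factor applied to the LoRA update B @ A, following PEFT's convention."""
    if use_rslora:
        # rank-stabilized LoRA: alpha / sqrt(r)
        return lora_alpha / math.sqrt(rank)
    # classic LoRA: alpha / r
    return lora_alpha / rank


if __name__ == "__main__":
    # Values from the scripts: --lora_rank 8 --lora_alpha 16 --use_rslora
    print(f"rsLoRA scaling:     {lora_scaling(16, 8, True):.4f}")
    print(f"plain LoRA scaling: {lora_scaling(16, 8, False):.4f}")
```

Both scripts use --lora_rank 8 and --lora_alpha 16. With those values the scaling is 16 / sqrt(8) ≈ 5.66 instead of 16 / 8 = 2. At the same alpha, the adapter update is therefore about 2.8 times larger than with plain LoRA, which matters when reusing learning rates tuned without rsLoRA.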
2. 32-node distributed training script
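#!/bin/bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Rendezvous info is normally injected by the scheduler; fall back to a local run if it is missing
if [ -z "${MASTER_ADDR}" ]; then
    echo "MASTER_ADDR is not set, falling back to localhost"
    MASTER_ADDR=localhost
fi
# Keep an injected MASTER_PORT, otherwise default to 6000 (it is passed to swift below)
MASTER_PORT=${MASTER_PORT:-6000}
if [ -z "${RANK}" ]; then
    # No RANK injected: single-node run
    NNODES=1
    NODE_RANK=0
else
    # Otherwise use the environment variables (32 nodes in this run)
    NODE_RANK=${RANK}
    NNODES=${PET_NNODES:-32}
fi
echo "start training: NNODES=${NNODES} NODE_RANK=${NODE_RANK} MASTER=${MASTER_ADDR}:${MASTER_PORT}"
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NNODES=$NNODES \
NODE_RANK=$NODE_RANK \
MASTER_ADDR=$MASTER_ADDR \
MASTER_PORT=$MASTER_PORT \
NPROC_PER_NODE=8 \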
swift sft \
--model_type internvl2-8b \
--model_id_or_path /datas/datasets/InternVL2-8B \
--sft_type lora \
--use_rslora \
--lora_rank 8 \
--lora_alpha 16 \
--num_train_epochs 6 \
--dtype bf16 \
--max_length 4096 \
--eval_steps 2000 \
--save_steps 20 \
--save_total_limit 10 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 32 \
--logging_steps 3 \
--resume_from_checkpoint /ms-swift-2.3.2/add_internVL2/ckpts/cluster0/internvl2-8b/v6-20240925-141815/checkpoint-367 \
--resume_only_model \
--dataset /datasets/train_swift/swift_my_html_0926.jsonl \
--save_on_each_node False \
--output_dir /ms-swift-2.3.2/add_internVL2/ckpts/0926/cluster${NODE_RANK} \
--deepspeed default-zero2
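The effective batch size follows from the launch settings. Each optimizer step consumes per_device_train_batch_size × gradient_accumulation_steps × NPROC_PER_NODE × NNODES samples. The sketch below works through that arithmetic for both scripts. It treats the 3 million samples as exact and ignores how the trainer handles a final partial accumulation window, so the real step count may differ by one.

```python
import math


def global_batch_size(per_device: int, grad_accum: int, nproc_per_node: int, nnodes: int) -> int:
    """Samples per optimizer step; ZeRO-2 partitions optimizer states and gradients but is still data parallel."""
    return per_device * grad_accum * nproc_per_node * nnodes


def steps_per_epoch(num_samples: int, global_batch: int) -> int:
    """Optimizer steps per epoch, counting a final partial batch as one step (idealized)."""
    return math.ceil(num_samples / global_batch)


if __name__ == "__main__":
    single = global_batch_size(1, 32, 8, 1)   # single-node script
    multi = global_batch_size(1, 32, 8, 32)   # 32-node script
    for label, gbs in (("1 node", single), ("32 nodes", multi)):
        print(f"{label}: {gbs} samples/step, {steps_per_epoch(3_000_000, gbs)} steps/epoch")
```

This gives 256 samples per step (11,719 steps per epoch) on one node and 8,192 samples per step (367 steps per epoch) on 32 nodes. Because the per-device batch and accumulation settings are the same in both scripts, going from 1 to 32 nodes makes each step process 32 times more samples rather than shortening it. If the learning rate was tuned for one configuration, it may need revisiting for the other.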
3. Training completed: with two resumptions from checkpoints, training on the 3 million samples finished in 14 hours.
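Both resumptions depend on --resume_from_checkpoint pointing at the right directory. Below is a sketch of a helper that picks the newest checkpoint. It assumes the <model_type>/v<N>-<timestamp>/checkpoint-<step> layout visible in the resume path of the 32-node script, and the function name is hypothetical. The helper ranks by run version first because --resume_only_model reloads only the model weights, so the step counter presumably restarts in each new run directory.

```python
import re
import sys
from pathlib import Path
from typing import Optional, Tuple

# Layout assumed from the resume path in the 32-node script:
#   <output_dir>/<model_type>/v<N>-<YYYYMMDD-HHMMSS>/checkpoint-<step>
_RUN_RE = re.compile(r"^v(\d+)-")
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(model_dir: str) -> Optional[Path]:
    """Highest-step checkpoint inside the highest-numbered run that has any checkpoint."""
    best_key: Optional[Tuple[int, int]] = None
    best_path: Optional[Path] = None
    for run in Path(model_dir).iterdir():
        run_match = _RUN_RE.match(run.name)
        if not (run.is_dir() and run_match):
            continue
        for ckpt in run.iterdir():
            ckpt_match = _CKPT_RE.match(ckpt.name)
            if not (ckpt.is_dir() and ckpt_match):
                continue
            # Compare as integers so checkpoint-40 ranks below checkpoint-367
            key = (int(run_match.group(1)), int(ckpt_match.group(1)))
            if best_key is None or key > best_key:
                best_key, best_path = key, ckpt
    return best_path


if __name__ == "__main__":
    if len(sys.argv) == 2:
        print(latest_checkpoint(sys.argv[1]))
    else:
        print("usage: python latest_ckpt.py <output_dir>/<model_type>")
```

Call it on the model-type folder, for example /ms-swift-2.3.2/add_internVL2/ckpts/cluster0/internvl2-8b, and pass the result to --resume_from_checkpoint. A checkpoint written while a job was crashing may be incomplete, so check its files before resuming from it.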
4. Results of the image-to-code model
III. Appendix
Python dependency versions
Package Version Editable project location
----------------------------- -------------- -----------------------------------------
absl-py 2.1.0
accelerate 0.30.1
addict 2.4.0
aiofiles 23.2.1
aiohttp 3.9.5
aiosignal 1.3.1
aliyun-python-sdk-core 2.15.2
aliyun-python-sdk-kms 2.16.5
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
anyio 4.4.0
apex 0.1-ascend
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
ascendebug 0.1.0
asttokens 2.4.1
async-lru 2.0.4
async-timeout 4.0.3
attr 0.3.2
attrdict 2.0.1
attrs 23.2.0
auto-tune 0.1.0
Babel 2.15.0
backcall 0.2.0
beautifulsoup4 4.12.3
binpacking 1.5.2
bleach 6.1.0
boto3 1.35.8
botocore 1.35.8
cachetools 5.5.0
certifi 2024.2.2
cffi 1.12.3
charset-normalizer 3.3.2
click 8.1.7
comm 0.2.2
contourpy 1.1.1
cpm-kernels 1.0.11
crcmod 1.7
cryptography 43.0.1
cycler 0.12.1
dacite 1.8.1
dataflow 0.0.1
datasets 2.19.1
debugpy 1.8.2
decorator 5.1.1
deepspeed 0.15.1
defusedxml 0.7.1
dill 0.3.8
distro 1.9.0
docstring_parser 0.16
einops 0.8.0
eval_type_backport 0.2.0
evaluate 0.4.2
exceptiongroup 1.2.1
executing 2.0.1
fastapi 0.112.4
fastjsonschema 2.20.0
ffmpy 0.4.0
filelock 3.14.0
fonttools 4.52.1
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2024.3.1
future 1.0.0
google-auth 2.34.0
google-auth-oauthlib 1.0.0
gradio 4.43.0
gradio_client 1.3.0
grpcio 1.66.1
h11 0.14.0
hccl 0.1.0
hccl-parser 0.1
hjson 3.1.0
httpcore 1.0.5
httpx 0.27.0
huggingface-hub 0.24.6
idna 3.7
importlib_metadata 7.1.0
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.29.5
ipython 8.12.3
isoduration 20.11.0
jedi 0.19.1
jieba 0.42.1
Jinja2 3.1.4
jiter 0.5.0
jmespath 0.10.0
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
jupyter_client 8.6.2
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.2
jupyter_server_terminals 0.5.3
jupyterlab 4.2.4
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
kiwisolver 1.4.5
llm-engine 0.0.1
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.7.5
matplotlib-inline 0.1.7
mdurl 0.1.2
mistune 3.0.2
mmcv 1.7.0
modelscope 1.18.0
mpmath 1.3.0
ms-swift 2.3.2
msadvisor 1.0.0
multidict 6.0.5
multiprocess 0.70.16
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.1
ninja 1.11.1.1
nltk 3.9.1
notebook_shim 0.2.4
numpy 1.24.4
oauthlib 3.2.2
omegaconf 2.3.0
op-compile-tool 0.1.0
op-gen 0.1
op-test-frame 0.1
opc-tool 0.1.0
openai 1.44.1
opencv-python 4.9.0.80
orjson 3.10.7
oss2 2.19.0
overrides 7.7.0
packaging 24.0
pandas 2.0.3
pandocfilters 1.5.1
parso 0.8.4
pathlib2 2.3.7.post1
peft 0.12.0
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.0
pkgutil_resolve_name 1.3.10
platformdirs 4.2.2
pluggy 1.5.0
prometheus_client 0.20.0
prompt_toolkit 3.0.47
protobuf 5.27.0
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 16.1.0
pyarrow-hotfix 0.6
pyasn1 0.6.0
pyasn1_modules 0.4.0
pybind11 2.12.0
pycparser 2.22
pycryptodome 3.20.0
pydantic 2.9.1
pydantic_core 2.23.3
pydub 0.25.1
pyext 0.7
Pygments 2.18.0
PyJWT 2.9.0
pyparsing 3.1.2
pytest 8.2.1
python-dateutil 2.9.0.post0
python-json-logger 2.0.7
python-multipart 0.0.9
pytz 2024.1
PyYAML 6.0.1
pyzmq 26.0.3
referencing 0.35.1
regex 2024.5.15
requests 2.32.2
requests-oauthlib 2.0.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.8.0
rouge 1.0.1
rpds-py 0.19.0
rsa 4.9
ruff 0.6.4
s3transfer 0.10.2
safetensors 0.4.3
schedule-search 0.0.1
scipy 1.10.1
semantic-version 2.10.0
Send2Trash 1.8.3
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
shtab 1.7.1
simplejson 3.19.3
six 1.16.0
sniffio 1.3.1
sortedcontainers 2.4.0
soupsieve 2.5
stack-data 0.6.3
starlette 0.38.5
sympy 1.12
te 0.4.0
tensorboard 2.14.0
tensorboard-data-server 0.7.2
terminado 0.18.1
tiktoken 0.7.0
timm 1.0.9
tinycss2 1.3.0
tokenizers 0.19.1
tomli 2.0.1
tomlkit 0.12.0
torch 2.2.0
torch-npu 2.2.0.post1
torchvision 0.17.0
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
transformers 4.42.4
transformers-stream-generator 0.0.5
trl 0.10.1
typer 0.12.5
types-python-dateutil 2.9.0.20240316
typing_extensions 4.12.0
tyro 0.8.10
tzdata 2024.1
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.6
wcwidth 0.2.13
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 12.0
Werkzeug 3.0.4
wheel 0.43.0
xxhash 3.4.1
yapf 0.40.1
yarl 1.9.4
zipp 3.19.0
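To compare another machine against the list above, the sketch below parses the pip list text and reports every package whose installed version differs. The script name check_env.py is hypothetical. Several entries, such as te, hccl and op-gen, appear to come from the CANN toolkit rather than PyPI. Differences in those entries may only reflect how the toolkit is installed on that machine.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Tuple


def parse_pip_list(text: str) -> Dict[str, str]:
    """Turn pip list table output into {package name: version}."""
    pins: Dict[str, str] = {}
    for line in text.splitlines():
        parts = line.split()
        # Skip blank lines, the header row and the dashed separator row
        if len(parts) < 2 or parts[0] == "Package" or set(parts[0]) == {"-"}:
            continue
        pins[parts[0]] = parts[1]
    return pins


def installed_version(pkg: str) -> Optional[str]:
    """Installed version, trying both '-' and '_' spellings; None if not found."""
    for name in (pkg, pkg.replace("-", "_")):
        try:
            return version(name)
        except PackageNotFoundError:
            continue
    return None


def diff_against_installed(pins: Dict[str, str]) -> List[Tuple[str, str, Optional[str]]]:
    """(package, appendix version, installed version or None) for every mismatch."""
    mismatches = []
    for name, pinned in sorted(pins.items()):
        installed = installed_version(name)
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: python check_env.py <file containing the pip list above>")
    else:
        with open(sys.argv[1], encoding="utf-8") as f:
            for name, pinned, installed in diff_against_installed(parse_pip_list(f.read())):
                print(f"{name}: appendix {pinned}, installed {installed or 'missing'}")
```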