Distributed Training of a Multimodal Model with Swift on 910B Chips: A Practical Walkthrough

Published 2024-12-11 10:57

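I. Environment Preparation

1. Prerequisites

First, prepare the training machines and the data. The author used 32 nodes of 910B NPUs and 3 million web-page training samples.

2. Environment setup

To set up the environment, first install the multimodal training framework ms-swift, then install torch-npu and deepspeed.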

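# Install ms-swift (installing from source is currently recommended; once a release is published it can be installed directly with pip)
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
# Install torch-npu
pip install torch-npu decorator
# Install deepspeed
pip install deepspeed
# The full list of Python dependency versions is in the appendix at the end of this article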
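
After installation it can help to confirm that the key packages match the versions in the appendix. The following is a minimal sketch, not part of the original workflow: the helper name `check_versions` and the choice of which packages to check are assumptions, while the version pins are copied from the appendix.

```python
from importlib import metadata

# Pins copied from the appendix of this article.
EXPECTED = {
    "ms-swift": "2.3.2",
    "torch": "2.2.0",
    "torch-npu": "2.2.0.post1",
    "deepspeed": "0.15.1",
    "transformers": "4.42.4",
    "peft": "0.12.0",
}


def check_versions(expected):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {installed}")
```

An empty result means every listed package is installed at the pinned version. A source install of ms-swift may report a different version string than the one in the appendix.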

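3. Environment verification

1) Verify the torch environment. Check in the intelligent-computing platform's IDE that the environment is correct; to save compute resources, run the check in a single-card IDE instance.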

#test_env.py
from transformers.utils import is_torch_npu_available
import torch
 
print(is_torch_npu_available())  # True
print(torch.npu.device_count())  # 1
print(torch.randn(10, device='npu:0'))

A normal response looks like the following:

[Screenshot: normal output of test_env.py]
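
On a machine where torch-npu may be missing, the check above fails with an import or attribute error rather than a clear message. A more defensive variant could look like the sketch below. The function name `npu_env_report` and the report keys are assumptions made for illustration; `torch.npu.is_available()` and `torch.npu.device_count()` are provided once `torch_npu` has been imported.

```python
import importlib.util


def npu_env_report():
    """Collect basic facts about the torch / torch-npu installation."""
    report = {
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_npu_installed": importlib.util.find_spec("torch_npu") is not None,
        "npu_available": False,
        "device_count": 0,
    }
    if report["torch_installed"] and report["torch_npu_installed"]:
        import torch
        import torch_npu  # noqa: F401  (importing it registers the torch.npu backend)

        report["npu_available"] = bool(torch.npu.is_available())
        report["device_count"] = int(torch.npu.device_count())
    return report


if __name__ == "__main__":
    for key, value in npu_env_report().items():
        print(f"{key}: {value}")
```

On the single-card IDE instance described above, `device_count` would be expected to be 1.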

2) Check the NPU status with the npu-smi info command.

A normal response looks like the following:

[Screenshot: normal output of npu-smi info]
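
If the status check needs to run from a script, for example before launching a job, the command can be wrapped as in the sketch below. The helper name `npu_smi_info` and the 30-second timeout are assumptions; only the `npu-smi info` command itself comes from the article.

```python
import shutil
import subprocess


def npu_smi_info(timeout=30):
    """Return the text printed by `npu-smi info`, or None if the tool is absent."""
    exe = shutil.which("npu-smi")
    if exe is None:
        return None  # not on PATH, e.g. the Ascend environment is not set up
    result = subprocess.run(
        [exe, "info"],
        capture_output=True,
        text=True,
        timeout=timeout,
        check=False,
    )
    return result.stdout


if __name__ == "__main__":
    output = npu_smi_info()
    print(output if output is not None else "npu-smi not found on PATH")
```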

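II. Training Task

The training task in this walkthrough is to teach the multimodal large model InternVL2 to generate HTML code from images: given a screenshot of a web page, the model produces the corresponding HTML code. The training set contains 3 million samples, drawn from the WebSight dataset on Hugging Face and from an internal HTML training set. Because InternVL2's vision encoder is small, uses little device memory, and requires no special operators, swift and torch-npu can run it without compatibility problems.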
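
The article does not show the layout of the training files. As a rough sketch, assuming the `query` / `response` / `images` JSONL layout described in the ms-swift 2.x custom-dataset documentation, one screenshot-to-HTML record could be built as follows. The prompt text, the image path, and the helper names are hypothetical.

```python
import json


def make_record(image_path, html):
    """Build one training record: a screenshot in, HTML code out."""
    return {
        # "<image>" marks where the picture is placed in the prompt (assumed convention).
        "query": "<image>Generate the HTML code for this web page screenshot.",
        "response": html,
        "images": [image_path],
    }


def to_jsonl_line(record):
    """Serialize one record as a single JSONL line without escaping non-ASCII text."""
    return json.dumps(record, ensure_ascii=False)


if __name__ == "__main__":
    # Hypothetical sample; real records would come from WebSight or the internal set.
    sample = make_record(
        "/path/to/screenshot_0001.png",
        "<html><body><h1>Hello</h1></body></html>",
    )
    print(to_jsonl_line(sample))
```

Each line of the `.jsonl` file passed to `--dataset` would then hold one such object.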

1. Single-node, 8-card training script

source /usr/local/Ascend/ascend-toolkit/set_env.sh
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NPROC_PER_NODE=8 \
swift sft \
    --model_type internvl2-8b \
    --model_id_or_path /datas/datasets/InternVL2-8B \
    --sft_type lora \
    --use_rslora \
    --lora_rank 8 \
    --lora_alpha 16 \
    --num_train_epochs 1 \
    --dtype bf16 \
    --max_length 4096 \
    --eval_steps 1000 \
    --save_steps 50 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --logging_steps 5 \
    --dataset /datasets/train_swift/swift_html_0924_2.jsonl \
    --output_dir /ms-swift-2.3.2/add_internVL2/ckpts \
    --deepspeed default-zero2
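
The script combines `--lora_rank 8`, `--lora_alpha 16`, and `--use_rslora`. With rank-stabilized LoRA the adapter update is scaled by alpha / sqrt(rank) instead of the usual alpha / rank, so the same two numbers give a noticeably larger scale. The sketch below only illustrates that arithmetic; the function name `lora_scaling` is an assumption and is not a swift or PEFT API.

```python
import math


def lora_scaling(alpha, rank, use_rslora=False):
    """Scale applied to the low-rank update B @ A."""
    if use_rslora:
        return alpha / math.sqrt(rank)  # rank-stabilized LoRA
    return alpha / rank  # standard LoRA


if __name__ == "__main__":
    print(lora_scaling(16, 8))                             # 2.0
    print(round(lora_scaling(16, 8, use_rslora=True), 3))  # 5.657
```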

2. 32-node distributed training script

source /usr/local/Ascend/ascend-toolkit/set_env.sh
if [ -z "${MASTER_ADDR}" ]; then
    echo "MASTER_ADDR is not set"
    MASTER_ADDR=localhost
    MASTER_PORT=6000
fi
if [ -z "${RANK}" ]; then
    NNODES=1
    NODE_RANK=0
else
    # Otherwise take the values from the environment variables
    NODE_RANK=${RANK}
    NNODES=${PET_NNODES}
fi
echo "start training instructions"
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NNODES=${NNODES:-32} \
NODE_RANK=$NODE_RANK \
MASTER_ADDR=$MASTER_ADDR \
NPROC_PER_NODE=8 \
swift sft \
    --model_type internvl2-8b \
    --model_id_or_path /datas/datasets/InternVL2-8B \
    --sft_type lora \
    --use_rslora \
    --lora_rank 8 \
    --lora_alpha 16 \
    --num_train_epochs 6 \
    --dtype bf16 \
    --max_length 4096 \
    --eval_steps 2000 \
    --save_steps 20 \
    --save_total_limit 10 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --logging_steps 3 \
    --resume_from_checkpoint /ms-swift-2.3.2/add_internVL2/ckpts/cluster0/internvl2-8b/v6-20240925-141815/checkpoint-367 \
    --resume_only_model \
    --dataset /datasets/train_swift/swift_my_html_0926.jsonl \
    --save_on_each_node False \
    --output_dir /ms-swift-2.3.2/add_internVL2/ckpts/0926/cluster${NODE_RANK} \
    --deepspeed default-zero2
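
Both scripts use `--per_device_train_batch_size 1` and `--gradient_accumulation_steps 32`, so the number of samples consumed per optimizer step grows with the number of devices. The sketch below works through that arithmetic for the two configurations. The 3,000,000 figure is the article's approximate dataset size, the helper names are assumptions, and the trainer's own rounding or any held-out validation split may shift the step count slightly.

```python
import math


def global_batch_size(per_device, grad_accum, devices_per_node, nodes):
    """Samples consumed per optimizer step across the whole job."""
    return per_device * grad_accum * devices_per_node * nodes


def steps_per_epoch(num_samples, batch):
    """Approximate optimizer steps needed for one pass over the data (rounded up)."""
    return math.ceil(num_samples / batch)


if __name__ == "__main__":
    single = global_batch_size(1, 32, 8, 1)   # single node, 8 cards
    multi = global_batch_size(1, 32, 8, 32)   # 32 nodes, 8 cards each
    print(single, steps_per_epoch(3_000_000, single))  # 256 11719
    print(multi, steps_per_epoch(3_000_000, multi))    # 8192 367
```

With 32 nodes each optimizer step covers 8192 samples, so `--save_steps 20` writes a checkpoint roughly every 164,000 samples.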

3. Training completed. After resuming from a checkpoint twice, training on the 3 million samples finished in 14 hours.


4. Results of the image-to-code model after training

[Two screenshots: sample outputs of the trained image-to-code model]

III. Appendix

Python dependency versions

Package                       Version        Editable project location
----------------------------- -------------- -----------------------------------------
absl-py                       2.1.0
accelerate                    0.30.1
addict                        2.4.0
aiofiles                      23.2.1
aiohttp                       3.9.5
aiosignal                     1.3.1
aliyun-python-sdk-core        2.15.2
aliyun-python-sdk-kms         2.16.5
annotated-types               0.7.0
antlr4-python3-runtime        4.9.3
anyio                         4.4.0
apex                          0.1-ascend
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
ascendebug                    0.1.0
asttokens                     2.4.1
async-lru                     2.0.4
async-timeout                 4.0.3
attr                          0.3.2
attrdict                      2.0.1
attrs                         23.2.0
auto-tune                     0.1.0
Babel                         2.15.0
backcall                      0.2.0
beautifulsoup4                4.12.3
binpacking                    1.5.2
bleach                        6.1.0
boto3                         1.35.8
botocore                      1.35.8
cachetools                    5.5.0
certifi                       2024.2.2
cffi                          1.12.3
charset-normalizer            3.3.2
click                         8.1.7
comm                          0.2.2
contourpy                     1.1.1
cpm-kernels                   1.0.11
crcmod                        1.7
cryptography                  43.0.1
cycler                        0.12.1
dacite                        1.8.1
dataflow                      0.0.1
datasets                      2.19.1
debugpy                       1.8.2
decorator                     5.1.1
deepspeed                     0.15.1
defusedxml                    0.7.1
dill                          0.3.8
distro                        1.9.0
docstring_parser              0.16
einops                        0.8.0
eval_type_backport            0.2.0
evaluate                      0.4.2
exceptiongroup                1.2.1
executing                     2.0.1
fastapi                       0.112.4
fastjsonschema                2.20.0
ffmpy                         0.4.0
filelock                      3.14.0
fonttools                     4.52.1
fqdn                          1.5.1
frozenlist                    1.4.1
fsspec                        2024.3.1
future                        1.0.0
google-auth                   2.34.0
google-auth-oauthlib          1.0.0
gradio                        4.43.0
gradio_client                 1.3.0
grpcio                        1.66.1
h11                           0.14.0
hccl                          0.1.0
hccl-parser                   0.1
hjson                         3.1.0
httpcore                      1.0.5
httpx                         0.27.0
huggingface-hub               0.24.6
idna                          3.7
importlib_metadata            7.1.0
importlib_resources           6.4.0
iniconfig                     2.0.0
ipykernel                     6.29.5
ipython                       8.12.3
isoduration                   20.11.0
jedi                          0.19.1
jieba                         0.42.1
Jinja2                        3.1.4
jiter                         0.5.0
jmespath                      0.10.0
joblib                        1.4.2
json5                         0.9.25
jsonpointer                   3.0.0
jsonschema                    4.23.0
jsonschema-specifications     2023.12.1
jupyter_client                8.6.2
jupyter_core                  5.7.2
jupyter-events                0.10.0
jupyter-lsp                   2.2.5
jupyter_server                2.14.2
jupyter_server_terminals      0.5.3
jupyterlab                    4.2.4
jupyterlab_pygments           0.3.0
jupyterlab_server             2.27.3
kiwisolver                    1.4.5
llm-engine                    0.0.1
Markdown                      3.7
markdown-it-py                3.0.0
MarkupSafe                    2.1.5
matplotlib                    3.7.5
matplotlib-inline             0.1.7
mdurl                         0.1.2
mistune                       3.0.2
mmcv                          1.7.0
modelscope                    1.18.0
mpmath                        1.3.0
ms-swift                      2.3.2          
msadvisor                     1.0.0
multidict                     6.0.5
multiprocess                  0.70.16
nbclient                      0.10.0
nbconvert                     7.16.4
nbformat                      5.10.4
nest-asyncio                  1.6.0
networkx                      3.1
ninja                         1.11.1.1
nltk                          3.9.1
notebook_shim                 0.2.4
numpy                         1.24.4
oauthlib                      3.2.2
omegaconf                     2.3.0
op-compile-tool               0.1.0
op-gen                        0.1
op-test-frame                 0.1
opc-tool                      0.1.0
openai                        1.44.1
opencv-python                 4.9.0.80
orjson                        3.10.7
oss2                          2.19.0
overrides                     7.7.0
packaging                     24.0
pandas                        2.0.3
pandocfilters                 1.5.1
parso                         0.8.4
pathlib2                      2.3.7.post1
peft                          0.12.0
pexpect                       4.9.0
pickleshare                   0.7.5
pillow                        10.3.0
pip                           24.0
pkgutil_resolve_name          1.3.10
platformdirs                  4.2.2
pluggy                        1.5.0
prometheus_client             0.20.0
prompt_toolkit                3.0.47
protobuf                      5.27.0
psutil                        5.9.8
ptyprocess                    0.7.0
pure-eval                     0.2.2
py-cpuinfo                    9.0.0
pyarrow                       16.1.0
pyarrow-hotfix                0.6
pyasn1                        0.6.0
pyasn1_modules                0.4.0
pybind11                      2.12.0
pycparser                     2.22
pycryptodome                  3.20.0
pydantic                      2.9.1
pydantic_core                 2.23.3
pydub                         0.25.1
pyext                         0.7
Pygments                      2.18.0
PyJWT                         2.9.0
pyparsing                     3.1.2
pytest                        8.2.1
python-dateutil               2.9.0.post0
python-json-logger            2.0.7
python-multipart              0.0.9
pytz                          2024.1
PyYAML                        6.0.1
pyzmq                         26.0.3
referencing                   0.35.1
regex                         2024.5.15
requests                      2.32.2
requests-oauthlib             2.0.0
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          13.8.0
rouge                         1.0.1
rpds-py                       0.19.0
rsa                           4.9
ruff                          0.6.4
s3transfer                    0.10.2
safetensors                   0.4.3
schedule-search               0.0.1
scipy                         1.10.1
semantic-version              2.10.0
Send2Trash                    1.8.3
sentencepiece                 0.2.0
setuptools                    69.5.1
shellingham                   1.5.4
shtab                         1.7.1
simplejson                    3.19.3
six                           1.16.0
sniffio                       1.3.1
sortedcontainers              2.4.0
soupsieve                     2.5
stack-data                    0.6.3
starlette                     0.38.5
sympy                         1.12
te                            0.4.0
tensorboard                   2.14.0
tensorboard-data-server       0.7.2
terminado                     0.18.1
tiktoken                      0.7.0
timm                          1.0.9
tinycss2                      1.3.0
tokenizers                    0.19.1
tomli                         2.0.1
tomlkit                       0.12.0
torch                         2.2.0
torch-npu                     2.2.0.post1
torchvision                   0.17.0
tornado                       6.4.1
tqdm                          4.66.4
traitlets                     5.14.3
transformers                  4.42.4
transformers-stream-generator 0.0.5
trl                           0.10.1
typer                         0.12.5
types-python-dateutil         2.9.0.20240316
typing_extensions             4.12.0
tyro                          0.8.10
tzdata                        2024.1
uri-template                  1.3.0
urllib3                       2.2.2
uvicorn                       0.30.6
wcwidth                       0.2.13
webcolors                     24.6.0
webencodings                  0.5.1
websocket-client              1.8.0
websockets                    12.0
Werkzeug                      3.0.4
wheel                         0.43.0
xxhash                        3.4.1
yapf                          0.40.1
yarl                          1.9.4
zipp                          3.19.0
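
To reuse this list on another machine, the `pip list` table can be turned into pinned requirement lines. The following is a minimal sketch; the function name `pip_list_to_requirements` is an assumption. Entries that appear to ship with the Ascend toolkit rather than PyPI (for example apex 0.1-ascend, hccl, or te) would probably need to be removed from the result by hand.

```python
def pip_list_to_requirements(text):
    """Convert `pip list` table text into 'name==version' pins."""
    pins = []
    for line in text.splitlines():
        parts = line.split()
        if len(parts) < 2:
            continue  # blank or malformed line
        if parts[0] == "Package" or set(line.strip()) <= {"-", " "}:
            continue  # header row or dashed separator row
        pins.append(f"{parts[0]}=={parts[1]}")
    return pins


if __name__ == "__main__":
    sample = """Package    Version     Editable project location
---------- ----------- -------------------------
deepspeed  0.15.1
torch      2.2.0
torch-npu  2.2.0.post1
"""
    print("\n".join(pip_list_to_requirements(sample)))
```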

This article is reposted from AI遇见云. Author: 周华健
