Documentation | Torch4keras | Examples
安装稳定版
pip install bert4torch安装最新版
pip install git+https://github.com/Tongjilibo/bert4torch- 注意事项:pip包的发布慢于git上的开发版本,git clone注意引用路径,注意权重是否需要转换
- 测试用例:
git clone https://github.com/Tongjilibo/bert4torch,修改example中的预训练模型文件路径和数据路径即可启动脚本 - 自行训练:针对自己的数据,修改相应的数据处理代码块
- 开发环境:原使用
torch==1.10版本进行开发,现已切换到torch2.0开发,如其他版本遇到不适配,欢迎反馈
-
LLM模型: 加载chatglm、llama、 baichuan、ziya、bloom等开源大模型权重进行推理和微调
-
核心功能:加载bert、roberta、albert、xlnet、nezha、bart、RoFormer、RoFormer_V2、ELECTRA、GPT、GPT2、T5、GAU-alpha、ERNIE等预训练权重继续进行finetune、并支持在bert基础上灵活定义自己模型
-
丰富示例:包含llm、pretrain、sentence_classfication、sentence_embedding、sequence_labeling、relation_extraction、seq2seq、serving等多种解决方案
-
实验验证:已在公开数据集实验验证,使用如下examples数据集
-
易用trick:集成了常见的trick,即插即用
-
其他特性:加载transformers库模型一起使用;调用方式简洁高效;有训练进度条动态展示;配合torchinfo打印参数量;默认Logger和Tensorboard简便记录训练过程;自定义fit过程,满足高阶需求
-
训练过程:
2022-10-28 23:16:10 - Start Training 2022-10-28 23:16:10 - Epoch: 1/2 5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] test_acc: 0.98045. best_test_acc: 0.98045 2022-10-28 23:16:27 - Epoch: 2/2 5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] test_acc: 0.98280. best_test_acc: 0.98280 2022-10-28 23:16:44 - Finish Training
| 功能 | bert4torch | transformers | 备注 |
|---|---|---|---|
| 训练进度条 | ✅ | ✅ | 进度条打印loss和定义的metrics |
| 分布式训练dp/ddp | ✅ | ✅ | torch自带dp/ddp |
| 各类callbacks | ✅ | ✅ | 日志/tensorboard/earlystop/wandb等 |
| 大模型推理,stream/batch输出 | ✅ | ✅ | 各个模型是通用的,无需单独维护脚本 |
| 大模型微调 | ✅ | ✅ | lora依赖peft库,pv2自带 |
| 丰富tricks | ✅ | ❌ | 对抗训练等tricks即插即用 |
| 代码简洁易懂,自定义空间大 | ✅ | ❌ | 代码复用度高, keras代码训练风格 |
| 仓库的维护能力/影响力/使用量/兼容性 | ❌ | ✅ | 目前仓库个人维护 |
| 更新日期 | bert4torch | torch4keras | 版本说明 |
|---|---|---|---|
| 20231126 | 0.4.0 | 0.1.5 | 修复flash_attn的bug, stream_generate支持仅输出last_token |
| 20231119 | 0.3.9 | 0.1.5 | 修复random_sample采样n>1, 新增Yi-6B, 支持flash_attn |
| 20231112 | 0.3.8 | 0.1.5 | 支持chatglm 32k的rope_ratio,config中可以指定mapping, 增加m3e和bge |
| 20231106 | 0.3.7 | 0.1.5 | 大部分模型文件无需convert,修复multi_query_group_num在int4/int8下bug, 简化build_transformer_model中配置到config中 |
- 20231126:修复flash_attn的bug, stream_generate支持仅输出last_token
- 20231119:修复random_sample采样n>1, 新增Yi-6B, 支持flash_attn
- 20231112:支持chatglm 32k的rope_ratio,config中可以指定mapping, 增加m3e和bge
- 20231106:🔥大部分模型文件无需convert,修复multi_query_group_num在int4/int8下bug, 简化
build_transformer_model中配置到config中
- 若无说明则使用权重自带的
pytorch_model.bin和config.json
| 模型分类 | 模型名称 | 权重来源 | 权重链接 | 备注(若有) |
|---|---|---|---|---|
| bert | bert-base-chinese | 谷歌bert的torch版 | torch | config |
| chinese_L-12_H-768_A-12 | 谷歌 | github, tf | 转换命令, config | |
| chinese-bert-wwm-ext | HFL | tf/torch,torch | ||
| bert-base-multilingual-cased | huggingface | torch | config | |
| macbert | HFL | tf/torch,torch | ||
| wobert | 追一科技 | tf,torch_base,torch_plus_base | ||
| guwenbert | ethanyt | torch | config | |
| roberta | chinese-roberta-wwm-ext | HFL | tf/torch,torch | |
| roberta-small/tiny | 追一科技 & UER | tf,torch | 转换脚本 | |
| roberta-base-english | huggingface | torch | config | |
| albert | albert | brightmart | tf,torch,torch | |
| nezha | NEZHA | 华为 | tf,torch | |
| xlnet | chinese-xlnet | HFL | tf/torch | config |
| deberta | Erlangshen-DeBERTa-v2 | IDEA | torch | |
| electra | Chinese-ELECTRA | HFL | tf,torch | |
| ernie | ernie | 百度文心 | paddle,torch | |
| roformer | roformer | 追一科技 | tf,torch | |
| roformer_v2 | 追一科技 | tf,torch | ||
| simbert | simbert | 追一科技 | tf,torch_base | 转换脚本 |
| simbert_v2/roformer-sim | 追一科技 | tf,torch | ||
| gau | GAU-alpha | 追一科技 | tf | 转换脚本 |
| gpt | CDial-GPT | thu-coai | torch | config |
| gpt2 | cmp_lm(26亿) | 清华 | torch | config |
| gpt2-chinese-cluecorpussmall | UER | torch | config | |
| gpt2-ml | imcaspar | tf,torch | config | |
| bart | bart_base_chinese | 复旦fnlp | torch, v1.0, v2.0 | config |
| t5 | t5 | UER | torch | config_base, config_small |
| mt5 | 谷歌 | torch | config | |
| t5_pegasus | 追一科技 | tf | config_base, config_small | |
| chatyuan v1&v2 | clue-ai | torch | config | |
| PromptCLUE | clue-ai | torch | config | |
| chatglm | chatglm-6b | THUDM | github, v0.1.0, v1.1.0, int8, int4 | config |
| chatglm2-6b | THUDM | github, v2, int4, 32k | config | |
| chatglm3-6b | THUDM | github, v3, 32k | config | |
| llama | llama | github | config | |
| llama-2 | github, 7b, 7b-chat, 13b, 13b-chat | config | ||
| chinese_llama_alpaca | HFL | github | config | |
| Belle_llama | LianjiaTech | github, 7B-2M-enc | 合成说明、config | |
| Ziya | IDEA-CCNL | v1, v1.1, pretrain-v1 | config | |
| Baichuan | baichuan-inc | github, 7B, 13B-Base, 13B-Chat | config | |
| Baichuan2 | baichuan-inc | github, 7B-Base, 7B-Chat, 13B-Base, 13B-Chat | config | |
| vicuna | lmsys | 7b-v1.5 | config | |
| Yi | 01-ai | github, 6B, 6B-200K | config | |
| bloom | bloom | bigscience | bloom-560m, bloomz-560m | config |
| Qwen | Qwen | 阿里云 | github, 7B, 7B-Chat | config |
| InternLM | InternLM | 上海人工智能实验室 | github, 7B-Chat, 7B | config |
| Falcon | Falcon | tiiuae | hf, RW-1B, 7B, 7B-Instruct | config |
| embedding | text2vec-base-chinese | shibing624 | torch | |
| m3e | moka-ai | torch | config | |
| bge | BAAI | torch | config |
- 感谢苏神实现的bert4keras,本实现有不少地方参考了bert4keras的源码,在此衷心感谢大佬的无私奉献;
- 其次感谢项目bert4pytorch,也是在该项目的指引下给了我用pytorch来复现bert4keras的想法和思路。
@misc{bert4torch,
title={bert4torch},
author={Bo Li},
year={2022},
howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}
- Wechat & Star History Chart
![]() 微信号 |
![]() 微信群 |
Star History Chart |


