
Fine-tuned PaddleOCR-VL model cannot be loaded or run for inference with FastDeploy #5525

@megemini

Description


I fine-tuned a PaddleOCR-VL model in AI Studio, following https://aistudio.baidu.com/projectdetail/9857242

Deploying it with the following command fails:

aistudio@jupyter-942478-9857242:~$ python -m fastdeploy.entrypoints.openai.api_server \
>     --model /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT \
>     --port 8185 \
>     --metrics-port 8186 \
>     --engine-worker-queue-port 8187 \
>     --max-model-len 16384 \
>     --max-num-batched-tokens 16384 \
>     --gpu-memory-utilization 0.7 \
>     --max-num-seqs 256
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:718: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
  warnings.warn(warning_message)
INFO     2025-12-12 12:54:19,655 43112 api_server.py[line:86] Number of api-server workers: 1.
[2025-12-12 12:54:19,665] [    INFO] - Using download source: huggingface
[2025-12-12 12:54:19,666] [    INFO] - Loading configuration file /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/config.json
[2025-12-12 12:54:19,666] [ WARNING] - You are using a model of type paddleocr_vl to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
[2025-12-12 12:54:19,667] [ WARNING] - You are using a model of type paddleocr_vl to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
/home/aistudio/external-libraries/lib/python3.10/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. Support for replacing an already imported distutils is deprecated. In the future, this condition will fail. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml
  warnings.warn(
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddleslim/common/load_model.py:20: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources as pkg
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/model_executor/graph_optimization/utils.py:21: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
[2025-12-12 12:54:22,422] [    INFO] - Using download source: huggingface
[2025-12-12 12:54:22,423] [    INFO] - Loading configuration file /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/generation_config.json
Traceback (most recent call last):
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/preprocess.py", line 73, in create_processor
    Processor = load_input_processor_plugins()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/plugins/input_processor/__init__.py", line 26, in load_input_processor_plugins
    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
AssertionError: Only one plugin is allowed to be loaded.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/entrypoints/openai/api_server.py", line 715, in <module>
    main()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/entrypoints/openai/api_server.py", line 698, in main
    if not load_engine():
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/entrypoints/openai/api_server.py", line 124, in load_engine
    if not engine.start(api_server_pid=args.port):
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/engine/engine.py", line 134, in start
    self.engine.create_data_processor()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/engine/common_engine.py", line 161, in create_data_processor
    self.data_processor = self.input_processor.create_processor()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/preprocess.py", line 116, in create_processor
    self.processor = PaddleOCRVLProcessor(
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py", line 65, in __init__
    super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/text_processor.py", line 179, in __init__
    self.tokenizer = self._load_tokenizer()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/text_processor.py", line 628, in _load_tokenizer
    return AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side="left", use_fast=True)
  File "/home/aistudio/external-libraries/lib/python3.10/site-packages/paddleformers/transformers/auto/tokenizer.py", line 329, in from_pretrained
    raise ValueError(
ValueError: Tokenizer class Ernie4_5_Tokenizer does not exist or is not currently imported.
/opt/conda/envs/python35-paddle120-env/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
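The error that actually stops startup is the final `ValueError`: paddleformers' `AutoTokenizer.from_pretrained` could not resolve the class name `Ernie4_5_Tokenizer`. The log also shows paddleformers being imported from `/home/aistudio/external-libraries/...` while FastDeploy is imported from the conda environment's `site-packages`. The installed paddleformers may therefore not be the version this FastDeploy release expects. Below is a minimal diagnostic sketch that uses only the standard library. It assumes, following the Hugging Face convention that paddleformers appears to mirror, that the class name is read from the `tokenizer_class` field of the checkpoint's `tokenizer_config.json`. The helper names are hypothetical.

```python
from __future__ import annotations

import importlib.util
import json
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path

MODEL_DIR = "/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT"


def declared_tokenizer_class(model_dir: str) -> str | None:
    """Return `tokenizer_class` from tokenizer_config.json, or None if absent.

    Assumption: a Hugging Face-style layout, where AutoTokenizer picks the
    tokenizer class from this field.
    """
    cfg = Path(model_dir) / "tokenizer_config.json"
    if not cfg.is_file():
        return None
    return json.loads(cfg.read_text(encoding="utf-8")).get("tokenizer_class")


def package_info(name: str) -> dict:
    """Report the installed version of `name` and the file it would be imported from."""
    spec = importlib.util.find_spec(name)
    try:
        ver = version(name)
    except PackageNotFoundError:
        # Not installed, or the distribution name differs from the import name.
        ver = None
    return {"version": ver, "origin": spec.origin if spec else None}


if __name__ == "__main__":
    print("tokenizer_class:", declared_tokenizer_class(MODEL_DIR))
    for pkg in ("paddleformers", "fastdeploy"):
        print(pkg, package_info(pkg))
```

If `origin` for paddleformers points into `external-libraries`, that copy shadows any paddleformers installed alongside FastDeploy in the conda environment.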

Running inference from a Python script fails with the same error:

from fastdeploy.entrypoints.llm import LLM
# Load the model
llm = LLM(model="/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT")

outputs = llm.chat(
    messages=[
        {"role": "user", "content": [ {"type": "image_url", "image_url": {"url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"}},
                                     {"type": "text", "text": "OCR:{}"}]}
    ],
    chat_template_kwargs={"enable_thinking": False})

# Output the results
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs.text
    reasoning_text = output.outputs.reasoning_content
    print(generated_text)

Output:

[2025-12-12 12:58:45,335] [    INFO] - Using download source: huggingface
[2025-12-12 12:58:45,336] [    INFO] - Loading configuration file /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/config.json
[2025-12-12 12:58:45,337] [ WARNING] - You are using a model of type paddleocr_vl to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
[2025-12-12 12:58:45,338] [ WARNING] - You are using a model of type paddleocr_vl to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
[2025-12-12 12:58:45,420] [    INFO] - Using download source: huggingface
[2025-12-12 12:58:45,422] [    INFO] - Loading configuration file /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/generation_config.json
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/preprocess.py:73, in InputPreprocessor.create_processor(self)
     71 from fastdeploy.plugins.input_processor import load_input_processor_plugins
---> 73 Processor = load_input_processor_plugins()
     74 self.processor = Processor(
     75     model_name_or_path=self.model_name_or_path,
     76     reasoning_parser_obj=reasoning_parser_obj,
     77     tool_parser_obj=tool_parser_obj,
     78 )

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/plugins/input_processor/__init__.py:26, in load_input_processor_plugins()
     25 plugins = load_plugins_by_group(group=PLUGINS_GROUP)
---> 26 assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     27 return next(iter(plugins.values()))()

AssertionError: Only one plugin is allowed to be loaded.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[5], line 3
      1 from fastdeploy.entrypoints.llm import LLM
      2 # Load the model
----> 3 llm = LLM(model="/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT")
      5 outputs = llm.chat(
      6     messages=[
      7         {"role": "user", "content": [ {"type": "image_url", "image_url": {"url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"}},
      8                                      {"type": "text", "text": "OCR:{}"}]}
      9     ],
     10     chat_template_kwargs={"enable_thinking": False})
     12 # Output the results

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/entrypoints/llm.py:100, in LLM.__init__(self, model, revision, tokenizer, enable_logprob, chat_template, **kwargs)
     96 self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)
     98 self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.model_config.max_model_len)
--> 100 self.llm_engine.start()
    102 self.mutex = threading.Lock()
    103 self.req_output = dict()

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/engine/engine.py:134, in LLMEngine.start(self, api_server_pid)
    131 self.launch_components()
    133 self.engine.start()
--> 134 self.engine.create_data_processor()
    135 self.data_processor = self.engine.data_processor
    137 # If block numer is specified and model is deployed in mixed mode, start cache manager first

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/engine/common_engine.py:161, in EngineService.create_data_processor(self)
    153 def create_data_processor(self):
    154     self.input_processor = InputPreprocessor(
    155         self.cfg.model_config,
    156         self.cfg.structured_outputs_config.reasoning_parser,
   (...)
    159         self.cfg.tool_parser,
    160     )
--> 161     self.data_processor = self.input_processor.create_processor()

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/preprocess.py:116, in InputPreprocessor.create_processor(self)
    111 elif "PaddleOCRVL" in architecture:
    112     from fastdeploy.input.paddleocr_vl_processor import (
    113         PaddleOCRVLProcessor,
    114     )
--> 116     self.processor = PaddleOCRVLProcessor(
    117         config=self.model_config,
    118         model_name_or_path=self.model_name_or_path,
    119         limit_mm_per_prompt=self.limit_mm_per_prompt,
    120         mm_processor_kwargs=self.mm_processor_kwargs,
    121         reasoning_parser_obj=reasoning_parser_obj,
    122     )
    123 elif "PaddleOCRVL" in architecture:
    124     from fastdeploy.input.paddleocr_vl_processor import (
    125         PaddleOCRVLProcessor,
    126     )

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py:65, in PaddleOCRVLProcessor.__init__(self, config, model_name_or_path, limit_mm_per_prompt, mm_processor_kwargs, reasoning_parser_obj, tool_parser_obj, enable_processor_cache)
     44 def __init__(
     45     self,
     46     config,
   (...)
     52     enable_processor_cache=False,
     53 ):
     54     """
     55     Initialize PaddleOCRVLProcessor instance.
     56 
   (...)
     63         tool_parser_obj: Tool parser instance
     64     """
---> 65     super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
     66     data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
     67     processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/text_processor.py:179, in DataProcessor.__init__(self, model_name_or_path, reasoning_parser_obj, tool_parser_obj)
    177 self.decode_status = dict()
    178 self.tool_parser_dict = dict()
--> 179 self.tokenizer = self._load_tokenizer()
    180 data_processor_logger.info(
    181     f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
    182                         eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
    183 )
    185 from paddleformers.trl.llm_utils import get_eos_token_id

File /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/fastdeploy/input/text_processor.py:628, in DataProcessor._load_tokenizer(self)
    625 else:
    626     from paddleformers.transformers import AutoTokenizer
--> 628     return AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side="left", use_fast=True)

File ~/external-libraries/lib/python3.10/site-packages/paddleformers/transformers/auto/tokenizer.py:329, in AutoTokenizer.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
    327     tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
    328 if tokenizer_class is None:
--> 329     raise ValueError(
    330         f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
    331     )
    333 # Bind PaddleTokenizerMixin
    334 tokenizer_class = _bind_paddle_mixin_if_available(tokenizer_class)

ValueError: Tokenizer class Ernie4_5_Tokenizer does not exist or is not currently imported.

@jzhang533
