[Lunar Lake] UR_RESULT_ERROR_DEVICE_LOST #780


Open · rpolyano opened this issue Feb 4, 2025 · 15 comments

Labels: Ecosystem (PyTorch ecosystem related), Functionality, XPU/GPU (XPU/GPU specific issues)

rpolyano commented Feb 4, 2025

Describe the bug

Trying to load the openbmb/MiniCPM-o-2_6 model results in:

Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)
  File "...py/nightingale/server.py", line 49, in __init__
    self.model = self.model.to(self._device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...py/nightingale/server_test.py", line 10, in <module>
    service = MiniCPMService()
              ^^^^^^^^^^^^^^^^
RuntimeError: Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)

If I add .eval() after loading the model (see the commented line in the snippet below), it crashes my entire desktop and sends me back to the login screen.

I have also tried this in the docker.io/intel/intel-extension-for-pytorch:2.5.10-xpu Docker container, with the same result.

Full code snippet:

import torch
from transformers import AutoModel, AutoTokenizer


def _select_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.xpu.is_available():
        return torch.device("xpu")
    if torch.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

class MiniCPMService:

    def __init__(self) -> None:
        super().__init__()

        self._device = _select_device()
        self._torch_dtype = (
            torch.bfloat16 if self._device.type != "cpu" else torch.float16
        )

        print(f'Running on {self._device} with dtype {self._torch_dtype}')

        self.model = AutoModel.from_pretrained(
            'openbmb/MiniCPM-o-2_6',
            trust_remote_code=True,
            attn_implementation='sdpa', # sdpa or flash_attention_2
            torch_dtype=self._torch_dtype,  # was hard-coded to torch.bfloat16, bypassing the CPU fallback above
            init_vision=True,
            init_audio=False,
            init_tts=False
        )


        self.model = self.model.to(self._device)
        # self.model = self.model.eval().to(self._device) # This eval() results in a full system crash
        self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)

Versions

collect_env.py itself fails to run:

Traceback (most recent call last):
  File ".../collect_env.py", line 19, in <module>
    import intel_extension_for_pytorch as ipex
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/intel_extension_for_pytorch/__init__.py", line 147, in <module>
    from . import _dynamo
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/intel_extension_for_pytorch/_dynamo/__init__.py", line 4, in <module>
    from torch._inductor.compile_fx import compile_fx
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 49, in <module>
    from torch._inductor.debug import save_args_for_compile_fx_inner
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/torch/_inductor/debug.py", line 26, in <module>
    from . import config, ir  # noqa: F811, this is needed
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/torch/_inductor/ir.py", line 77, in <module>
    from .runtime.hints import ReductionHint
  File "/home/roman/.local/share/virtualenvs/nightingale-uqI8m8sk/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py", line 36, in <module>
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
  File "/home/roman/.pyenv/versions/3.12.8/lib/python3.12/dataclasses.py", line 1289, in fields
    raise TypeError('must be called with a dataclass type or instance') from None
TypeError: must be called with a dataclass type or instance

louie-tsai self-assigned this on Feb 5, 2025

louie-tsai (Contributor) commented

@rpolyano
Sorry for the late response.
Since you hit a device-lost error, could you use xpu-smi to check whether the right devices are visible inside Docker?
Here are the instructions:
https://intel.github.io/xpumanager/smi_user_guide.html#discover-the-devices-in-this-machine
Thanks,
Louie

louie-tsai (Contributor) commented

@rpolyano
Moreover, your code needs flash-attn, which doesn't support XPU or CPU:
https://pypi.org/project/flash-attn/
In that case, we might not be able to run the code because of the flash-attn package dependency.
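
For context, a minimal sketch (not from the original report) of one way to avoid the flash-attn dependency: select PyTorch's built-in sdpa attention on non-CUDA devices, reusing the model name and loader arguments from this thread:

import torch
from transformers import AutoModel

# flash-attn ships CUDA-only kernels, so fall back to PyTorch's built-in
# scaled-dot-product attention everywhere else (XPU, CPU, ...).
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "sdpa"

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-o-2_6",
    trust_remote_code=True,
    attn_implementation=attn_impl,
    torch_dtype=torch.bfloat16,
    init_vision=True,
    init_audio=False,
    init_tts=False,
)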

louie-tsai (Contributor) commented

@rpolyano
Do you happen to have a CUDA installation on the same machine?

louie-tsai added the XPU/GPU (XPU/GPU specific issues), Ecosystem (PyTorch ecosystem related), and Functionality labels on Mar 5, 2025
rpolyano (Author) commented Mar 6, 2025

Hi Louie, sorry for the delayed response.

Just to note: I am on an Asus Zenbook S14 - an Intel Lunar Lake laptop.

xpu-smi does not seem to work:

podman run --rm -it --privileged -v /dev/dri/by-path:/dev/dri/by-path --ipc=host docker.io/intel/intel-extension-for-pytorch:2.5.10-xpu \
xpu-smi discovery

No device discovered

However

podman run --rm -it --privileged -v /dev/dri/by-path:/dev/dri/by-path --ipc=host docker.io/intel/intel-extension-for-pytorch:2.5.10-xpu \
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__); [print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())];"
2.5.1+cxx11.abi
2.5.10+xpu
[0]: _XpuDeviceProperties(name='Intel(R) Graphics [0x64a0]', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.31294+21', total_memory=14386MB, max_compute_units=64, gpu_eu_count=64, gpu_subslice_count=8, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)

So it seems python can see it, but xpu-smi cannot.

I do not have CUDA on the same machine.
The flash-attn requirement is fairly recent; it wasn't required when I first posted this bug. I don't believe flash-attn is actually needed, since AFAIK you can use sdpa without it; I have been able to run this code on an AMD Navi GPU, which also doesn't support flash-attn.
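
For reference, the detection one-liner above, unrolled into a readable script (the same calls used in this thread, nothing new):

import torch
import intel_extension_for_pytorch as ipex  # registers the XPU backend on torch 2.5

print(torch.__version__)
print(ipex.__version__)
for i in range(torch.xpu.device_count()):
    print(f"[{i}]: {torch.xpu.get_device_properties(i)}")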

simonlui commented

IPEX 2.6.1 did not solve this issue, and I think enough time has passed since Lunar Lake's release (with an IPEX release in between) to warrant getting this solved properly. Running the default workflow in ComfyUI with Stable Diffusion 1.5 yields a similar error.

!!! Exception during processing !!! Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)
Traceback (most recent call last):
  File "/home/simonlui/Code_Repositories/ComfyUI/execution.py", line 327, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/execution.py", line 202, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/execution.py", line 174, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/home/simonlui/Code_Repositories/ComfyUI/execution.py", line 163, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/nodes.py", line 69, in encode
    return (clip.encode_from_tokens_scheduled(tokens), )
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/sd.py", line 152, in encode_from_tokens_scheduled
    pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/sd.py", line 213, in encode_from_tokens
    self.load_model()
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/sd.py", line 246, in load_model
    model_management.load_model_gpu(self.patcher)
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_management.py", line 595, in load_model_gpu
    return load_models_gpu([model])
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_management.py", line 590, in load_models_gpu
    loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_management.py", line 409, in model_load
    self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_management.py", line 438, in model_use_more_vram
    return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_patcher.py", line 827, in partially_load
    raise e
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_patcher.py", line 824, in partially_load
    self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_patcher.py", line 670, in load
    x[2].to(device_to)
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1343, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/site-packages/torch/nn/modules/module.py", line 930, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1329, in convert
    return t.to(
           ^^^^^
RuntimeError: Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)

Prompt executed in 7.75 seconds
Exception in thread Thread-4 (prompt_worker):
Traceback (most recent call last):
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/home/simonlui/Code_Repositories/ComfyUI/main.py", line 208, in prompt_worker
    comfy.model_management.soft_empty_cache()
  File "/home/simonlui/Code_Repositories/ComfyUI/comfy/model_management.py", line 1206, in soft_empty_cache
    torch.xpu.empty_cache()
  File "/home/simonlui/.conda/envs/comfyui/lib/python3.12/site-packages/intel_extension_for_pytorch/xpu/memory.py", line 22, in empty_cache
    intel_extension_for_pytorch._C._emptyCache()
RuntimeError: Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)

Seems odd that this can just happen like that.

> So it seems python can see it, but xpu-smi cannot.

xpu-smi most likely can't see it because it wasn't built to: as stated in its repository, it only supports Intel's enterprise Flex and Max GPU lines, and the open issue at intel/xpumanager#92 about broader support hasn't gone anywhere. The A770 seems to be the exception here, with unofficial support that works because an equivalent Flex GPU uses the same silicon die.

simonlui commented

If I use the nightly PyTorch wheel intended for BMG on Linux from #764 (comment) on my LNL laptop, the device-lost issue seems to be fixed, but the Stable Diffusion 1.5 default workflow still fails, producing either a black image or a garbage image.

[Image: garbage output from the SD 1.5 default workflow]

So there is more work left to be done. Digging further into the issue: Lunar Lake still isn't officially supported by PyTorch in any capacity on Ubuntu 24.04, which is what I run; that may be the cause, but I doubt it. I'm not sure whether the wheel is good enough to run LLMs; I don't have the model the issue opener used, so I cannot test that.

mr-cn commented Mar 25, 2025

I have this problem too.
Ultra 7 258V, Windows 11, running the latest IPEX:

(ipex) PS D:\ollama> python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__); [print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())];"
[W325 21:23:34.000000000 OperatorEntry.cpp:161] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
    registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\build\aten\src\ATen\RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\build\aten\src\ATen\RegisterCPU.cpp:30477
       new kernel: registered at G:\frameworks.ai.pytorch.ipex-gpu\build\Release\csrc\gpu\csrc\aten\generated\ATen\RegisterXPU.cpp:468 (function operator ())
2.6.0+xpu
2.6.10+xpu
[0]: _XpuDeviceProperties(name='Intel(R) Arc(TM) 140V GPU (16GB)', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.31441', total_memory=16841MB, max_compute_units=64, gpu_eu_count=64, gpu_subslice_count=8, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)

mr-cn commented Mar 31, 2025

IPEX 2.6 and 2.5 produce the same error. Currently only the Flux model has this problem. Is it driver-related or application-related?

To see the GUI go to: http://127.0.0.1:8188
got prompt
Failed to validate prompt for output 53:
* (prompt):
  - Required input is missing: images
* SaveImage 53:
  - Required input is missing: images
Output will be ignored
Using pytorch attention in VAE
Using pytorch attention in VAE
VAE load device: xpu:0, offload device: cpu, dtype: torch.bfloat16
model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16
model_type FLUX
Requested to load FluxClipModel_
loaded completely 9.5367431640625e+25 4777.53759765625 True
2025-03-31 21:46:20,619 - _logger.py - IPEX - INFO - Currently split master weight for xpu only support sgd
2025-03-31 21:46:20,643 - _logger.py - IPEX - INFO - Conv BatchNorm folding failed during the optimize process.
2025-03-31 21:46:20,652 - _logger.py - IPEX - INFO - Linear BatchNorm folding failed during the optimize process.
CLIP/text encoder model load device: xpu:0, offload device: cpu, current: xpu:0, dtype: torch.float16
clip missing: ['text_projection.weight']
Token indices sequence length is longer than the specified maximum sequence length for this model (149 > 77). Running this sequence through the model will result in indexing errors
2025-03-31 21:46:31,406 - _logger.py - IPEX - INFO - Currently split master weight for xpu only support sgd
2025-03-31 21:46:31,413 - _logger.py - IPEX - INFO - Conv BatchNorm folding failed during the optimize process.
2025-03-31 21:46:31,419 - _logger.py - IPEX - INFO - Linear BatchNorm folding failed during the optimize process.
Requested to load Flux
loaded completely 13420.0888359375 11350.067443847656 True
2025-03-31 21:47:01,037 - _logger.py - IPEX - INFO - Currently split master weight for xpu only support sgd
2025-03-31 21:47:01,049 - _logger.py - IPEX - INFO - Conv BatchNorm folding failed during the optimize process.
2025-03-31 21:47:01,058 - _logger.py - IPEX - INFO - Linear BatchNorm folding failed during the optimize process.
  0%|                                                                                           | 0/20 [00:01<?, ?it/s]
!!! Exception during processing !!! Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)
Traceback (most recent call last):
  File "D:\comfy\ComfyUI\execution.py", line 327, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\execution.py", line 202, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\execution.py", line 174, in _map_node_over_list
    process_inputs(input_dict, i)
  File "D:\comfy\ComfyUI\execution.py", line 163, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy_extras\nodes_custom_sampler.py", line 657, in sample
    samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 1008, in sample
    output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\patcher_extension.py", line 110, in execute
    return self.original(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 976, in outer_sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 959, in inner_sample
    samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\patcher_extension.py", line 110, in execute
    return self.original(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 738, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\k_diffusion\sampling.py", line 161, in sample_euler
    denoised = model(x, sigma_hat * s_in, **extra_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 390, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 939, in __call__
    return self.predict_noise(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 942, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 370, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 206, in calc_cond_batch
    return executor.execute(model, conds, x_in, timestep, model_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\patcher_extension.py", line 110, in execute
    return self.original(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\samplers.py", line 319, in _calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\model_base.py", line 137, in apply_model
    return comfy.patcher_extension.WrapperExecutor.new_class_executor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\patcher_extension.py", line 110, in execute
    return self.original(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\model_base.py", line 170, in _apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\ldm\flux\model.py", line 206, in forward
    out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\ldm\flux\model.py", line 145, in forward_orig
    img, txt = block(img=img,
               ^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\ldm\flux\layers.py", line 199, in forward
    img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ipex\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\ops.py", line 71, in forward
    return self.forward_comfy_cast_weights(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\comfy\ComfyUI\comfy\ops.py", line 67, in forward_comfy_cast_weights
    return torch.nn.functional.linear(input, weight, bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)

Prompt executed in 59.43 seconds
Exception in thread Thread-1 (prompt_worker):
Traceback (most recent call last):
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "D:\comfy\ComfyUI\main.py", line 208, in prompt_worker
    comfy.model_management.soft_empty_cache()
  File "D:\comfy\ComfyUI\comfy\model_management.py", line 1235, in soft_empty_cache
    torch.xpu.empty_cache()
  File "D:\comfy\ipex\Lib\site-packages\intel_extension_for_pytorch\xpu\memory.py", line 22, in empty_cache
    intel_extension_for_pytorch._C._emptyCache()
RuntimeError: Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)

Stonepia (Contributor) commented Apr 2, 2025

This might be a known issue: when system memory is under high pressure, it can trigger this UR error. You will typically see that system memory is almost full right before the error is raised.

We have reported this to the driver team for a fix. For now, I would suggest:

  1. Reduce memory usage, e.g. by using a smaller model or reducing the batch size.
  2. Use a machine with more memory if possible.

Related: intel/torch-xpu-ops#1324
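
To confirm whether memory pressure is the trigger, here is a minimal sketch (not from this thread) that reports device-memory usage around a workload; it assumes a recent PyTorch XPU build where torch.xpu mirrors the torch.cuda memory API (torch.xpu.memory_allocated / torch.xpu.memory_reserved):

import torch

def report_xpu_memory(tag: str) -> None:
    # Assumption: these calls mirror torch.cuda's memory API,
    # as on recent PyTorch XPU builds.
    allocated = torch.xpu.memory_allocated() / 2**20
    reserved = torch.xpu.memory_reserved() / 2**20
    print(f"[{tag}] allocated: {allocated:.0f} MiB, reserved: {reserved:.0f} MiB")

report_xpu_memory("before")
# ... load the model / run the failing workflow here ...
report_xpu_memory("after")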

simonlui commented Apr 2, 2025

> This might be a known issue: when system memory is under high pressure, it can trigger this UR error.

In the example I gave, I am using Stable Diffusion 1.5 with the default workflow, generating one image, and it still fails with this error. That is about as lightweight a model and image-diffusion example as you can run. The LNL laptop I use has 32 GB of RAM, half of which is allocated as VRAM to the iGPU. It doesn't make sense to me that it would fail for that reason.

Stonepia (Contributor) commented Apr 2, 2025

Well, for SD 1.5 there is no reason it should fail with this error; SD v3 can hit this UR error due to memory pressure.

Could you take a look at how much memory it uses while the model runs? Since you seem to be on Linux, you could use xpu-smi, for example:

watch -n 1 xpu-smi ps

If you are on Windows, you can simply check Task Manager.

simonlui commented Apr 2, 2025

> Could you take a look at how much memory it uses while the model runs? Since you seem to be on Linux, you could use xpu-smi, for example.

Lunar Lake does not work with xpu-smi, nor does anything newer than the original Alchemist series of Arc GPUs, per intel/xpumanager#92; xpu-smi discovery gives me "No device discovered" on Linux. I need to check what happens on Windows, but I have to reinstall the OS first since I broke the install for unrelated reasons. I'll be sure to include Task Manager monitoring or something of that nature while running the workflow.

Stonepia (Contributor) commented Apr 2, 2025

> Lunar Lake does not work with xpu-smi, nor does anything newer than the original Alchemist series of Arc GPUs, per intel/xpumanager#92; xpu-smi discovery gives me "No device discovered" on Linux.

Thanks for your patience! Honestly, I don't have a good lead yet, but let's start from the reproducer in this thread. Meanwhile, if you have a reproducer of your own, feel free to post it and we can give it a try locally.

Stonepia self-assigned this on Apr 2, 2025
simonlui commented Apr 2, 2025

@Stonepia It turns out I wasn't running the latest torch nightly on my LNL laptop (because of another unrelated issue). While retesting, I found that torch nightly 2.8.0.dev20250321+xpu fixes the device-lost issue under normal workloads. The only issue left is the one you linked: when too much memory is used, the system does not handle it gracefully and sometimes hard-locks.

louie-tsai (Contributor) commented

This should be the known issue: when memory pressure is too high, the GPU context breaks, and in extreme cases that results in UR_RESULT_ERROR_DEVICE_LOST.
You may find that host memory is full and never released. Temporary workarounds (see the sketch below):

  • Call torch.xpu.empty_cache() to manually release cached memory.
  • Decrease the batch size to reduce memory usage.

An internal issue has been created, and we will keep you posted.
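
A minimal sketch of the first workaround, with a stand-in matmul workload in place of a real model (illustrative only, not from this thread):

import torch

device = torch.device("xpu")

for step in range(100):
    # Stand-in workload: any op that allocates device memory.
    x = torch.randn(1024, 1024, device=device, dtype=torch.bfloat16)
    y = x @ x
    del x, y
    # Workaround from above: return cached blocks to the driver each
    # iteration so host/device memory pressure stays low.
    torch.xpu.empty_cache()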
