I am trying to conduct language model inference on a computer that has an Intel processor. I installed PyTorch for Intel GPUs (along with the Intel driver and Deep Learning tools) following the PyTorch and Intel instructions:
Getting Started on Intel GPU — PyTorch 2.6 documentation
PyTorch Prerequisites for Intel® GPUs
This is the processor on my computer:
13th Gen Intel(R) Core(TM) i5-1335U, 1300 MHz, 10 Core(s), 12 Logical Processor(s)
My log then reports "Using device xpu", which leads me to assume the issue is not with PyTorch or my processor.
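For context, the device is selected roughly as follows (a minimal sketch; my original selection script is not reproduced here, so this reconstruction is illustrative):

import torch

# Pick the Intel GPU ("xpu") backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
print(f"Using device {device}")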
The error arises when I run the model on this device in my masked language modeling inference code. Here is my inference code:
import torch
from collections import defaultdict
import logging
import numpy as np


def prediction_function(
    text: str,
    model,
    tokenizer,
    device,
    window_size: int = 512,
    overlap: int = 128,
    num_predictions: int = 5,
):
    all_predictions = defaultdict(list)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    num_tokens = len(tokens)
    # overlapping window loop to process text beyond 512 tokens
    for i in range(0, num_tokens, window_size - overlap):
        chunk_ids = tokens[i : min(i + window_size, num_tokens)]
        chunk_ids = chunk_ids[:512]
        chunk = tokenizer.decode(chunk_ids)
        chunk_inputs = tokenizer(
            chunk,
            return_tensors="pt",
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
        )
        # move all input tensors to the target device (cpu/xpu/cuda)
        chunk_inputs = {k: v.to(device) for k, v in chunk_inputs.items()}
        with torch.no_grad():
            outputs = model(**chunk_inputs)
        predictions = outputs.logits
        # positions of mask tokens within this chunk
        masked_indices = [
            idx
            for idx, token_id in enumerate(chunk_inputs["input_ids"][0])
            if token_id == tokenizer.mask_token_id
        ]
        logging.info(masked_indices)
        for masked_index in masked_indices:
            predicted_probs = predictions[0, masked_index]
            sorted_preds, sorted_idx = torch.sort(predicted_probs, descending=True)
            masked_predictions = []
            for k in range(num_predictions):
                predicted_index = int(sorted_idx[k].item())
                predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
                probability = torch.softmax(predicted_probs, dim=-1)[
                    predicted_index
                ].item()
                masked_predictions.append((predicted_token, probability))
            logging.info(f"Predictions for {masked_index}: {masked_predictions}")
            # offset the chunk-local index by the window start i
            all_predictions[masked_index + i].extend(masked_predictions)
    logging.info(f"All predictions: {all_predictions}")
    final_predictions = {}
    for masked_index, prediction_list in all_predictions.items():
        # group subword predictions
        subword_groups: dict = {}
        for token, prob in prediction_list:
            if token.startswith("##"):
                base_word = token[2:]  # remove "##" prefix
                if base_word not in subword_groups:
                    subword_groups[base_word] = []
                subword_groups[base_word].append((token, prob))
            else:  # whole word token
                subword_groups[token] = [(token, prob)]
        logging.info(f"Subword groups: {subword_groups}")
        whole_word_predictions = []
        for base_word, subword_list in subword_groups.items():
            max_prob = 0.0
            for subtoken, prob in subword_list:
                if prob > max_prob:
                    max_prob = prob
            whole_word_predictions.append((base_word, max_prob))
        # sort by prob
        sorted_predictions = sorted(
            whole_word_predictions, key=lambda x: x[1], reverse=True
        )
        # keep top num_predictions
        final_predictions[masked_index] = sorted_predictions[:num_predictions]
    logging.info(type(final_predictions))
    logging.info(f"Final predictions: {final_predictions}")
    return final_predictions
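For reference, the function is called from my backend endpoint roughly like this (an illustrative sketch; the checkpoint name and input text are placeholders, not the actual ones used in the app):

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Placeholder checkpoint; the real app loads its own BERT-style
# masked language model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
model = model.to(device).eval()

results = prediction_function(
    "The capital of Greece is [MASK].",
    model,
    tokenizer,
    device,
    num_predictions=5,
)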
I routinely receive some form of runtime error, such as:
Traceback (most recent call last):
File "C:\Users\jm9095\logion-app\src\backend\main.py", line 181, in prediction_endpoint
results = predict.prediction_function(
text,
...<5 lines>...
num_predictions=5,
)
File "C:\Users\jm9095\logion-app\src\backend\prediction\predict.py", line 59, in prediction_function
outputs = model(**chunk_inputs)
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\transformers\models\bert\modeling_bert.py", line 1461, in forward
outputs = self.bert(
input_ids,
...<9 lines>...
return_dict=return_dict,
)
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\transformers\models\bert\modeling_bert.py", line 1108, in forward
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
File "C:\Users\jm9095\AppData\Local\Programs\Python\Python313\Lib\site-packages\transformers\modeling_attn_mask_utils.py", line 448, in _prepare_4d_attention_mask_for_sdpa
if not is_tracing and torch.all(mask == 1):
~~~~~~~~~^^^^^^^^^^^
RuntimeError: Native API failed. Native API returns: 20 (UR_RESULT_ERROR_DEVICE_LOST)
Most recently, I encountered this error without making any changes to my code or system:
Traceback (most recent call last):
File "C:\Users\jm9095\logion-app\src\backend\main.py", line 181, in prediction_endpoint
results = predict.prediction_function(
text,
...<5 lines>...
num_predictions=5,
)
File "C:\Users\jm9095\logion-app\src\backend\prediction\predict.py", line 74, in prediction_function
predicted_index = int(sorted_idx[k].item())
~~~~~~~~~~~~~~~~~~^^
RuntimeError: Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)
I suspect the Intel driver, since I do not encounter this problem with other accelerators (e.g. Nvidia GPUs, Apple M-series) or when running on my PC's CPU; the issue arises only on the xpu device. But given that PyTorch does load the xpu device successfully, I am not sure this diagnosis is correct. Does anyone have suggestions on the source of the runtime error?
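A minimal check along the following lines isolates the failure to the xpu path (a sketch, assuming the same model and tokenizer objects are already loaded and a BERT-style mask token):

import torch

# Run one forward pass per backend; if only "xpu" raises a Native API
# error, the driver/runtime stack is suspect rather than the modeling code.
inputs = tokenizer("A short [MASK] sentence.", return_tensors="pt")
for dev in ("cpu", "xpu"):
    m = model.to(dev)
    batch = {k: v.to(dev) for k, v in inputs.items()}
    with torch.no_grad():
        out = m(**batch)
    print(dev, out.logits.shape)

Environment details collected on this machine follow.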
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__); [print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())];"
[W408 09:36:25.000000000 OperatorEntry.cpp:161] Warning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\build\aten\src\ATen\RegisterSchema.cpp:6
dispatch key: XPU
previous kernel: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\build\aten\src\ATen\RegisterCPU.cpp:30477
new kernel: registered at E:\frameworks.ai.pytorch.ipex-gpu\build\Release\csrc\gpu\csrc\aten\generated\ATen\RegisterXPU.cpp:468 (function operator ())
2.6.0.post0+xpu
[0]: _XpuDeviceProperties(name='Intel(R) UHD Graphics', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32413', total_memory=7102MB, max_compute_units=80, gpu_eu_count=80, gpu_subslice_count=10, max_work_group_size=512, max_num_sub_groups=64, sub_group_sizes=[8 16 32], has_fp16=1, has_fp64=0, has_atomic64=1)