@@ -18,7 +18,10 @@ def ipex_init(): # pylint: disable=too-many-statements
     if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
         return True, "Skipping IPEX hijack"
     else:
-        try: # force xpu device on torch compile and triton
+        try:
+            # force xpu device on torch compile and triton
+            # import inductor utils to get around lazy import
+            from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import
             torch._inductor.utils.GPU_TYPES = ["xpu"]
             torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
             from triton import backends as triton_backends # pylint: disable=import-error
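The key change in this hunk is that torch._inductor.utils is now imported explicitly before its attributes are patched; per the new comment, this works around torch's lazy import of the inductor submodules. A minimal standalone sketch of that pattern follows. The helper name force_xpu_on_inductor is hypothetical, and the broad except guard reflects that these are torch internals which older builds may not expose.

import torch

def force_xpu_on_inductor():
    try:
        # Import the lazily-loaded inductor utils so the module object exists,
        # then make inductor report "xpu" as the only GPU type.
        from torch._inductor import utils as torch_inductor_utils  # noqa: F401 pylint: disable=unused-import
        torch._inductor.utils.GPU_TYPES = ["xpu"]
        torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
    except Exception:
        # Private API: skip quietly on torch builds that do not ship it.
        pass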
@@ -187,11 +190,13 @@ def ipex_init(): # pylint: disable=too-many-statements
             ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
             ipex._C._DeviceProperties.major = 12
             ipex._C._DeviceProperties.minor = 1
+            ipex._C._DeviceProperties.L2_cache_size = 16 * 1024 * 1024 # A770 and A750
         else:
             torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
             torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
             torch._C._XpuDeviceProperties.major = 12
             torch._C._XpuDeviceProperties.minor = 1
+            torch._C._XpuDeviceProperties.L2_cache_size = 16 * 1024 * 1024 # A770 and A750

         # Fix functions with ipex:
         # torch.xpu.mem_get_info always returns the total memory as free memory
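Both branches now also publish an L2_cache_size on the device-properties class, hardcoded to 16 MiB for the Arc A770/A750 so callers reading CUDA-style properties get a plausible value. A hedged sketch of the same class-level patch, written as a guard that only fills the attribute in when it is missing (the helper name ensure_l2_cache_size is hypothetical):

import torch

def ensure_l2_cache_size(default_bytes: int = 16 * 1024 * 1024) -> None:
    # 16 MiB matches Arc A770/A750, as in the hunk above; setting the value on
    # the class makes every device-properties instance report it.
    props_cls = getattr(torch._C, "_XpuDeviceProperties", None)
    if props_cls is not None and not hasattr(props_cls, "L2_cache_size"):
        props_cls.L2_cache_size = default_bytes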
@@ -200,14 +205,15 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch._utils._get_available_device_type = lambda: "xpu"
         torch.has_cuda = True
         torch.cuda.has_half = True
-        torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
+        torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
         torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
         torch.backends.cuda.is_built = lambda *args, **kwargs: True
         torch.version.cuda = "12.1"
-        torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
+        torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
         torch.cuda.get_device_capability = lambda *args, **kwargs: (12, 1)
         torch.cuda.get_device_properties.major = 12
         torch.cuda.get_device_properties.minor = 1
+        torch.cuda.get_device_properties.L2_cache_size = 16 * 1024 * 1024 # A770 and A750
         torch.cuda.ipc_collect = lambda *args, **kwargs: None
         torch.cuda.utilization = lambda *args, **kwargs: 0
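The two getattr replacements follow one idea: prefer the real torch.xpu implementation when the installed torch exposes it, and keep the hardcoded stub only as a fallback for older builds. A small generalization of that pattern, assuming torch.xpu is available (the helper name hijack_cuda_attr is hypothetical; the attribute names and fallbacks are taken from the hunk above):

import torch

def hijack_cuda_attr(name, fallback):
    # Use the native torch.xpu attribute when it exists, otherwise install
    # the hardcoded stub on torch.cuda.
    setattr(torch.cuda, name, getattr(torch.xpu, name, fallback))

hijack_cuda_attr("is_bf16_supported", lambda *args, **kwargs: True)
hijack_cuda_attr("get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])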