Skip to content

Commit 60379d7

Browse files
authored
[Torch] fix sdpa's shape function when different batch size (#4137)
1 parent 80a3dfd commit 60379d7

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8991,8 +8991,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
89918991
" }\n"
89928992
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
89938993
" %int-1 = torch.constant.int -1\n"
8994+
" %int0 = torch.constant.int 0\n"
89948995
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
89958996
" %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
8997+
" %2 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8998+
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
8999+
" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9000+
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
9001+
" %6 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9002+
" %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>\n"
9003+
" %8 = call @__torch__.torch.jit._shape_functions.broadcast_three(%3, %5, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
9004+
" %9 = torch.aten.__getitem__.t %8, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9005+
" %10 = torch.aten._set_item.t %arg0, %int0, %9 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
89969006
" return %arg0 : !torch.list<int>\n"
89979007
" }\n"
89989008
" func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,10 +1348,12 @@ def aten〇_trilinear〡shape(i1: List[int], i2: List[int], i3: List[int], expan
13481348
@check_shape_function([
    Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
    Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
    Invocation(TensorOfShape(1, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Different batch size
])
def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False) -> List[int]:
    # Shape function for scaled_dot_product_attention.
    #
    # The result is the `query` shape with two entries overwritten:
    #   * the last entry becomes the last entry of `value`;
    #   * entry 0 becomes entry 0 of `broadcast_three` applied to the
    #     leading sizes of `query`, `key` and `value`.
    # `attn_mask`, `dropout_p`, `is_causal`, `scale` and `enable_gqa` are not
    # consulted when computing the shape.
    #
    # NOTE: the returned list is the very same list object as `query`; it is
    # updated in place rather than copied.

    # Overwrite the trailing dim first: the leading-dim broadcast below reads
    # `query[0]` after this assignment, so the order must be kept.
    query[-1] = value[-1]

    # Broadcast only the leading (index 0) sizes of the three inputs, each
    # wrapped in a one-element list.
    leading = upstream_shape_functions.broadcast_three([query[0]], [key[0]], [value[0]])
    query[0] = leading[0]
    return query
13561358

13571359
@check_shape_function([

0 commit comments

Comments
 (0)