@@ -8991,8 +8991,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
8991
8991
" }\n"
8992
8992
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
8993
8993
" %int-1 = torch.constant.int -1\n"
8994
+ " %int0 = torch.constant.int 0\n"
8994
8995
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
8995
8996
" %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
8997
+ " %2 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8998
+ " %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
8999
+ " %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9000
+ " %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
9001
+ " %6 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9002
+ " %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>\n"
9003
+ " %8 = call @__torch__.torch.jit._shape_functions.broadcast_three(%3, %5, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
9004
+ " %9 = torch.aten.__getitem__.t %8, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
9005
+ " %10 = torch.aten._set_item.t %arg0, %int0, %9 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
8996
9006
" return %arg0 : !torch.list<int>\n"
8997
9007
" }\n"
8998
9008
" func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
0 commit comments