Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVMGPU] Scatter codegen fails on unaligned shapes #19760

Open
Groverkss opened this issue Jan 22, 2025 · 1 comment
Open

[LLVMGPU] Scatter codegen fails on unaligned shapes #19760

Groverkss opened this issue Jan 22, 2025 · 1 comment
Labels
bug 🐞 Something isn't working

Comments

@Groverkss
Copy link
Contributor

Groverkss commented Jan 22, 2025

Input IR:

hal.executable @decode_bs1$async_dispatch_124 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>) {
    hal.executable.export public @decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store ordinal(0) layout(#hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
    ^bb0(%arg0: !hal.device, %arg1: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store() {
        %c7872 = arith.constant 7872 : index
        %c46656 = arith.constant 46656 : index
        %c0 = arith.constant 0 : index
        %c32_i64 = arith.constant 32 : i64
        %0 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
        %1 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
        %2 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
        %3 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
        %4 = arith.extui %0 : i32 to i64
        %5 = arith.extui %1 : i32 to i64
        %6 = arith.shli %5, %c32_i64 : i64
        %7 = arith.ori %4, %6 : i64
        %8 = arith.index_castui %7 : i64 to index
        %9 = arith.extui %2 : i32 to i64
        %10 = arith.extui %3 : i32 to i64
        %11 = arith.shli %10, %c32_i64 : i64
        %12 = arith.ori %9, %11 : i64
        %13 = arith.index_castui %12 : i64 to index
        %14 = util.assume.int %13<umin = 1, umax = 9007199254740991> : index
        %15 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c46656) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<32x100xf16>>
        %16 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c7872) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<32x3xi32>>
        %17 = flow.dispatch.workload.ordinal %14, 0 : index
        %18 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<?x26x2x32x16x100xf16>>{%17}
        %19 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%8) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x32x16x100xf16>>{%17}
        %20 = flow.dispatch.tensor.load %15, offsets = [0, 0], sizes = [32, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x100xf16>> -> tensor<32x100xf16>
        %21 = flow.dispatch.tensor.load %16, offsets = [0, 0], sizes = [32, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x3xi32>> -> tensor<32x3xi32>
        %22 = flow.dispatch.tensor.load %18, offsets = [0, 4, 0, 0, 0, 0], sizes = [%17, 1, 1, 32, 16, 100], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x26x2x32x16x100xf16>>{%17} -> tensor<?x32x16x100xf16>
        %23 = iree_linalg_ext.scatter dimension_map = [0, 1, 2] unique_indices(true) ins(%20, %21 : tensor<32x100xf16>, tensor<32x3xi32>) outs(%22 : tensor<?x32x16x100xf16>) {
        ^bb0(%arg0: f16, %arg1: f16):
          iree_linalg_ext.yield %arg0 : f16
        } -> tensor<?x32x16x100xf16>
        flow.dispatch.tensor.store %23, %19, offsets = [0, 0, 0, 0], sizes = [%17, 32, 16, 100], strides = [1, 1, 1, 1] : tensor<?x32x16x100xf16> -> !flow.dispatch.tensor<writeonly:tensor<?x32x16x100xf16>>{%17}
        return
      }
    }
  }
}

Reproduction instructions:

iree-compile <input-file>

IREE Commit: 525389c

@Groverkss Groverkss added the bug 🐞 Something isn't working label Jan 22, 2025
@Groverkss
Copy link
Contributor Author

Looks like bufferization is creating a copy from workgroup memory to global memory outside the scf.forall region. This may be a bufferization issue that needs to be fixed.

Error:

t.mlir:41:9: error: 'memref.copy' op write affecting operations on global resources are restricted to workgroup distributed contexts.
        flow.dispatch.tensor.store %23, %19, offsets = [0, 0, 0, 0], sizes = [%17, 32, 16, 100], strides = [1, 1, 1, 1] : tensor<?x32x16x100xf16> -> !flow.dispatch.tensor<writeonly:tensor<?x32x16x100xf16>>{%17}
        ^
t.mlir:41:9: note: see current operation: "memref.copy"(%25, %23) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : (memref<?x32x16x100xf16, #gpu.address_space<workgroup>>, memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
t.mlir:9:7: error: 'func.func' op failed on workgroup distribution verification
      func.func @decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store() {
      ^
t.mlir:9:7: note: see current operation: 
"func.func"() <{function_type = () -> (), sym_name = "decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store"}> ({
  %0 = "arith.constant"() <{value = 7872 : index}> : () -> index
  %1 = "arith.constant"() <{value = 46656 : index}> : () -> index
  %2 = "arith.constant"() <{value = 0 : index}> : () -> index
  %3 = "arith.constant"() <{value = 32 : i64}> : () -> i64
  %4 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
  %5 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
  %6 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
  %7 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
  %8 = "arith.extui"(%4) : (i32) -> i64
  %9 = "arith.extui"(%5) : (i32) -> i64
  %10 = "arith.shli"(%9, %3) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  %11 = "arith.ori"(%8, %10) : (i64, i64) -> i64
  %12 = "arith.index_castui"(%11) : (i64) -> index
  %13 = "arith.extui"(%6) : (i32) -> i64
  %14 = "arith.extui"(%7) : (i32) -> i64
  %15 = "arith.shli"(%14, %3) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  %16 = "arith.ori"(%13, %15) : (i64, i64) -> i64
  %17 = "arith.index_castui"(%16) : (i64) -> index
  %18 = "util.assume.int"(%17) <{assumptions = [[#util.int.assumption<umin = 1, umax = 9007199254740991>]]}> : (index) -> index
  %19 = "hal.interface.binding.subspan"(%1) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>
  "memref.assume_alignment"(%19) <{alignment = 64 : i32}> : (memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>) -> ()
  %20 = "hal.interface.binding.subspan"(%0) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>
  "memref.assume_alignment"(%20) <{alignment = 64 : i32}> : (memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>) -> ()
  %21 = "flow.dispatch.workload.ordinal"(%18) <{ordinal = 0 : index}> : (index) -> index
  %22 = "hal.interface.binding.subspan"(%2, %21) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>
  "memref.assume_alignment"(%22) <{alignment = 64 : i32}> : (memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>) -> ()
  %23 = "hal.interface.binding.subspan"(%12, %21) {alignment = 64 : index, binding = 2 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  "memref.assume_alignment"(%23) <{alignment = 1 : i32}> : (memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
  %24 = "memref.subview"(%22, %21) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0, 4, 0, 0, 0, 0>, static_sizes = array<i64: -9223372036854775808, 1, 1, 32, 16, 100>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>, index) -> memref<?x32x16x100xf16, strided<[2662400, 1600, 100, 1], offset: 409600>, #hal.descriptor_type<storage_buffer>>
  %25 = "memref.alloc"(%21) <{operandSegmentSizes = array<i32: 1, 0>}> : (index) -> memref<?x32x16x100xf16, #gpu.address_space<workgroup>>
  "gpu.barrier"() : () -> ()
  "memref.copy"(%24, %25) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : (memref<?x32x16x100xf16, strided<[2662400, 1600, 100, 1], offset: 409600>, #hal.descriptor_type<storage_buffer>>, memref<?x32x16x100xf16, #gpu.address_space<workgroup>>) -> ()
  "gpu.barrier"() : () -> ()
  "scf.forall"() <{mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>], operandSegmentSizes = array<i32: 0, 0, 0, 0>, staticLowerBound = array<i64: 0, 0>, staticStep = array<i64: 1, 64>, staticUpperBound = array<i64: 32, 100>}> ({
  ^bb0(%arg0: index, %arg1: index):
    %26 = "affine.min"(%arg1) <{map = affine_map<(d0) -> (-d0 + 100, 64)>}> : (index) -> index
    %27 = "memref.subview"(%19, %arg0, %arg1, %26) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>, index, index, index) -> memref<1x?xf16, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %28 = "memref.subview"(%20, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, 3>, static_strides = array<i64: 1, 1>}> : (memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>, index) -> memref<1x3xi32, strided<[3, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %29 = "memref.subview"(%25, %arg1, %21, %26) <{operandSegmentSizes = array<i32: 1, 1, 2, 0>, static_offsets = array<i64: 0, 0, 0, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, 32, 16, -9223372036854775808>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<?x32x16x100xf16, #gpu.address_space<workgroup>>, index, index, index) -> memref<?x32x16x?xf16, strided<[51200, 1600, 100, 1], offset: ?>, #gpu.address_space<workgroup>>
    "iree_linalg_ext.scatter"(%27, %28, %29) <{dimension_map = array<i64: 0, 1, 2>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = true}> ({
    ^bb0(%arg2: f16, %arg3: f16):
      "iree_linalg_ext.yield"(%arg2) : (f16) -> ()
    }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64]]>} : (memref<1x?xf16, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x3xi32, strided<[3, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<?x32x16x?xf16, strided<[51200, 1600, 100, 1], offset: ?>, #gpu.address_space<workgroup>>) -> ()
    "gpu.barrier"() : () -> ()
    "scf.forall.in_parallel"() ({
    ^bb0:
    }) : () -> ()
  }) : () -> ()
  "gpu.barrier"() : () -> ()
  "memref.copy"(%25, %23) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : (memref<?x32x16x100xf16, #gpu.address_space<workgroup>>, memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
  "gpu.barrier"() : () -> ()
  "func.return"() : () -> ()
}) {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUDistribute workgroup_size = [64, 1, 1] subgroup_size = 32>} : () -> ()
t.mlir:2:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>) {
  ^
t.mlir:2:3: note: see current operation: 
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg4: !hal.device, %arg5: index):
    %30:3 = "flow.dispatch.workgroup_count_from_slice"(%arg5) : (index) -> (index, index, index)
    "hal.return"(%30#0, %30#1, %30#2) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, sym_name = "decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store"} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "decode_bs1$async_dispatch_124_scatter_Dx32x16x100xf16_dispatch_tensor_store"}> ({
      %0 = "arith.constant"() <{value = 7872 : index}> : () -> index
      %1 = "arith.constant"() <{value = 46656 : index}> : () -> index
      %2 = "arith.constant"() <{value = 0 : index}> : () -> index
      %3 = "arith.constant"() <{value = 32 : i64}> : () -> i64
      %4 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
      %5 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
      %6 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
      %7 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
      %8 = "arith.extui"(%4) : (i32) -> i64
      %9 = "arith.extui"(%5) : (i32) -> i64
      %10 = "arith.shli"(%9, %3) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %11 = "arith.ori"(%8, %10) : (i64, i64) -> i64
      %12 = "arith.index_castui"(%11) : (i64) -> index
      %13 = "arith.extui"(%6) : (i32) -> i64
      %14 = "arith.extui"(%7) : (i32) -> i64
      %15 = "arith.shli"(%14, %3) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %16 = "arith.ori"(%13, %15) : (i64, i64) -> i64
      %17 = "arith.index_castui"(%16) : (i64) -> index
      %18 = "util.assume.int"(%17) <{assumptions = [[#util.int.assumption<umin = 1, umax = 9007199254740991>]]}> : (index) -> index
      %19 = "hal.interface.binding.subspan"(%1) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>
      "memref.assume_alignment"(%19) <{alignment = 64 : i32}> : (memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>) -> ()
      %20 = "hal.interface.binding.subspan"(%0) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>
      "memref.assume_alignment"(%20) <{alignment = 64 : i32}> : (memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>) -> ()
      %21 = "flow.dispatch.workload.ordinal"(%18) <{ordinal = 0 : index}> : (index) -> index
      %22 = "hal.interface.binding.subspan"(%2, %21) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>
      "memref.assume_alignment"(%22) <{alignment = 64 : i32}> : (memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>) -> ()
      %23 = "hal.interface.binding.subspan"(%12, %21) {alignment = 64 : index, binding = 2 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      "memref.assume_alignment"(%23) <{alignment = 1 : i32}> : (memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
      %24 = "memref.subview"(%22, %21) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0, 4, 0, 0, 0, 0>, static_sizes = array<i64: -9223372036854775808, 1, 1, 32, 16, 100>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<?x26x2x32x16x100xf16, #hal.descriptor_type<storage_buffer>>, index) -> memref<?x32x16x100xf16, strided<[2662400, 1600, 100, 1], offset: 409600>, #hal.descriptor_type<storage_buffer>>
      %25 = "memref.alloc"(%21) <{operandSegmentSizes = array<i32: 1, 0>}> : (index) -> memref<?x32x16x100xf16, #gpu.address_space<workgroup>>
      "gpu.barrier"() : () -> ()
      "memref.copy"(%24, %25) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : (memref<?x32x16x100xf16, strided<[2662400, 1600, 100, 1], offset: 409600>, #hal.descriptor_type<storage_buffer>>, memref<?x32x16x100xf16, #gpu.address_space<workgroup>>) -> ()
      "gpu.barrier"() : () -> ()
      "scf.forall"() <{mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>], operandSegmentSizes = array<i32: 0, 0, 0, 0>, staticLowerBound = array<i64: 0, 0>, staticStep = array<i64: 1, 64>, staticUpperBound = array<i64: 32, 100>}> ({
      ^bb0(%arg0: index, %arg1: index):
        %26 = "affine.min"(%arg1) <{map = affine_map<(d0) -> (-d0 + 100, 64)>}> : (index) -> index
        %27 = "memref.subview"(%19, %arg0, %arg1, %26) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<32x100xf16, strided<[100, 1], offset: 23328>, #hal.descriptor_type<storage_buffer>>, index, index, index) -> memref<1x?xf16, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %28 = "memref.subview"(%20, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, 3>, static_strides = array<i64: 1, 1>}> : (memref<32x3xi32, strided<[3, 1], offset: 1968>, #hal.descriptor_type<storage_buffer>>, index) -> memref<1x3xi32, strided<[3, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %29 = "memref.subview"(%25, %arg1, %21, %26) <{operandSegmentSizes = array<i32: 1, 1, 2, 0>, static_offsets = array<i64: 0, 0, 0, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, 32, 16, -9223372036854775808>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<?x32x16x100xf16, #gpu.address_space<workgroup>>, index, index, index) -> memref<?x32x16x?xf16, strided<[51200, 1600, 100, 1], offset: ?>, #gpu.address_space<workgroup>>
        "iree_linalg_ext.scatter"(%27, %28, %29) <{dimension_map = array<i64: 0, 1, 2>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = true}> ({
        ^bb0(%arg2: f16, %arg3: f16):
          "iree_linalg_ext.yield"(%arg2) : (f16) -> ()
        }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64]]>} : (memref<1x?xf16, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x3xi32, strided<[3, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<?x32x16x?xf16, strided<[51200, 1600, 100, 1], offset: ?>, #gpu.address_space<workgroup>>) -> ()
        "gpu.barrier"() : () -> ()
        "scf.forall.in_parallel"() ({
        ^bb0:
        }) : () -> ()
      }) : () -> ()
      "gpu.barrier"() : () -> ()
      "memref.copy"(%25, %23) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : (memref<?x32x16x100xf16, #gpu.address_space<workgroup>>, memref<?x32x16x100xf16, strided<[51200, 1600, 100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
      "gpu.barrier"() : () -> ()
      "func.return"() : () -> ()
    }) {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUDistribute workgroup_size = [64, 1, 1] subgroup_size = 32>} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", sym_visibility = "public", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>} : () -> ()
failed to translate executables

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug 🐞 Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant