
replace all-in-one pass with real pipeline #174


Merged: 3 commits into main on Jul 18, 2024

Conversation

@xurui1995 (Contributor) commented on Jul 18, 2024

The all-in-one pass is not debug-friendly; a real pipeline can expose what each individual pass does.
Track: #173
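
For context, a minimal sketch of what the change looks like on the C++ side, with the lowering assembled on an OpPassManager so the driver's instrumentation sees every pass boundary. This is an illustration, not the PR's actual code: createConvertOneDNNGraphToLinalgPass is an assumed factory name inferred from the "convert-onednn-graph-to-linalg" pass in the dumps below, while the canonicalize/CSE passes and PassPipelineRegistration are real upstream MLIR APIs.

#include <memory>
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"

// Assumed project-specific factory, named after the pass seen in the dumps.
std::unique_ptr<mlir::Pass> createConvertOneDNNGraphToLinalgPass();

void populateGCCPUPipeline(mlir::OpPassManager &pm) {
  pm.addPass(createConvertOneDNNGraphToLinalgPass()); // assumed name
  pm.addPass(mlir::createCanonicalizerPass());        // upstream MLIR pass
  pm.addPass(mlir::createCSEPass());                  // upstream MLIR pass
  // ... bufferization, OpenMP lowering, and LLVM conversion passes follow ...
}

// Registering the pipeline under the same flag keeps the old command line working:
static mlir::PassPipelineRegistration<> registration(
    "gc-cpu-pipeline", "Lower from OneDNN graph to LLVM (CPU)",
    populateGCCPUPipeline);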

Example: the output of the following command, first with the old all-in-one pass and then with the new pipeline:
./bin/gc-opt /home/xurui/gc_v2/test.mlir --gc-cpu-pipeline --mlir-print-ir-after-all

All-in-one pass:

// -----// IR Dump After GCCPUPipeline (gc-cpu-pipeline) //----- //
module {
  llvm.func @main_entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(256 : index) : i64
    %1 = llvm.mlir.constant(1 : index) : i64
    %2 = llvm.mlir.constant(128 : index) : i64
    %3 = llvm.mlir.constant(0 : index) : i64
    %4 = llvm.mlir.constant(0.000000e+00 : bf16) : bf16
    %5 = llvm.mlir.constant(512 : index) : i64
    omp.parallel {
      omp.wsloop {
        omp.loop_nest (%arg21, %arg22) : i64 = (%3, %3) to (%2, %0) step (%1, %1) {
          %6 = llvm.intr.stacksave : !llvm.ptr
          llvm.br ^bb1
        ^bb1:  // pred: ^bb0
          %7 = llvm.mul %arg21, %0 : i64
          %8 = llvm.add %7, %arg22 : i64
          %9 = llvm.getelementptr %arg15[%8] : (!llvm.ptr, i64) -> !llvm.ptr, bf16
          llvm.store %4, %9 : bf16, !llvm.ptr
          llvm.intr.stackrestore %6 : !llvm.ptr
          llvm.br ^bb2
        ^bb2:  // pred: ^bb1
          omp.yield
        }
        omp.terminator
      }
      omp.terminator
    }
    omp.parallel {
      omp.wsloop {
        omp.loop_nest (%arg21, %arg22) : i64 = (%3, %3) to (%2, %0) step (%1, %1) {
          %6 = llvm.intr.stacksave : !llvm.ptr
          llvm.br ^bb1
        ^bb1:  // pred: ^bb0
          llvm.br ^bb2(%3 : i64)
        ^bb2(%7: i64):  // 2 preds: ^bb1, ^bb3
          %8 = llvm.icmp "slt" %7, %5 : i64
          llvm.cond_br %8, ^bb3, ^bb4
        ^bb3:  // pred: ^bb2
          %9 = llvm.mul %arg21, %5 : i64
          %10 = llvm.add %9, %7 : i64
          %11 = llvm.getelementptr %arg1[%10] : (!llvm.ptr, i64) -> !llvm.ptr, bf16
          %12 = llvm.load %11 : !llvm.ptr -> bf16
          %13 = llvm.mul %7, %0 : i64
          %14 = llvm.add %13, %arg22 : i64
          %15 = llvm.getelementptr %arg8[%14] : (!llvm.ptr, i64) -> !llvm.ptr, bf16
          %16 = llvm.load %15 : !llvm.ptr -> bf16
          %17 = llvm.mul %arg21, %0 : i64
          %18 = llvm.add %17, %arg22 : i64
          %19 = llvm.getelementptr %arg15[%18] : (!llvm.ptr, i64) -> !llvm.ptr, bf16
          %20 = llvm.load %19 : !llvm.ptr -> bf16
          %21 = llvm.fpext %20 {fastmath = #arith.fastmath<contract>} : bf16 to f32
          %22 = llvm.fpext %16 {fastmath = #arith.fastmath<contract>} : bf16 to f32
          %23 = llvm.fpext %12 {fastmath = #arith.fastmath<contract>} : bf16 to f32
          %24 = llvm.fmul %23, %22  : f32
          %25 = llvm.fadd %21, %24  : f32
          %26 = llvm.fptrunc %25 {fastmath = #arith.fastmath<contract>} : f32 to bf16
          llvm.store %26, %19 : bf16, !llvm.ptr
          %27 = llvm.add %7, %1 : i64
          llvm.br ^bb2(%27 : i64)
        ^bb4:  // pred: ^bb2
          llvm.intr.stackrestore %6 : !llvm.ptr
          llvm.br ^bb5
        ^bb5:  // pred: ^bb4
          omp.yield
        }
        omp.terminator
      }
      omp.terminator
    }
    llvm.return
  }
  llvm.func @_mlir_ciface_main_entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {llvm.emit_c_interface} {
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %5 = llvm.extractvalue %0[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %6 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %7 = llvm.extractvalue %0[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %8 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %9 = llvm.extractvalue %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %10 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %11 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %12 = llvm.extractvalue %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %13 = llvm.extractvalue %8[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %14 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %15 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %16 = llvm.load %arg2 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %17 = llvm.extractvalue %16[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %18 = llvm.extractvalue %16[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %19 = llvm.extractvalue %16[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %20 = llvm.extractvalue %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %21 = llvm.extractvalue %16[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %22 = llvm.extractvalue %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %23 = llvm.extractvalue %16[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    llvm.call @main_entry(%1, %2, %3, %4, %5, %6, %7, %9, %10, %11, %12, %13, %14, %15, %17, %18, %19, %20, %21, %22, %23) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> ()
    llvm.return
  }
}


(gc-opt then prints the final module, which is identical to the single dump above.)
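
A hedged guess at why the old pass yields only one dump: a pass that runs its own nested PassManager hides the inner passes from the driver's IR-printing instrumentation, so --mlir-print-ir-after-all can only report one boundary, the whole pass. A sketch of that shape (class name taken from the dump header; the body is an assumption, not the project's actual code):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

struct GCCPUPipeline
    : public mlir::PassWrapper<GCCPUPipeline,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    // A private, nested pass manager: the outer driver cannot observe it.
    mlir::PassManager pm(&getContext());
    // ...all the individual lowering passes are added to pm here...
    if (mlir::failed(pm.run(getOperation())))
      signalPassFailure();
  }
};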

Pipeline:

// -----// IR Dump After ConvertOneDNNGraphToLinalg (convert-onednn-graph-to-linalg) //----- //
module {
  func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
    %cst = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x256xbf16>
    %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
    return %2 : tensor<128x256xbf16>
  }
}


// -----// IR Dump After LinalgGeneralizeNamedOpsPass (linalg-generalize-named-ops) //----- //
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<128x256xbf16>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    linalg.yield %in : bf16
  } -> tensor<128x256xbf16>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %3 = arith.mulf %in, %in_0 : bf16
    %4 = arith.addf %out, %3 : bf16
    linalg.yield %4 : bf16
  } -> tensor<128x256xbf16>
  return %2 : tensor<128x256xbf16>
}

// -----// IR Dump After MathLegalizeToF32 (math-legalize-to-f32) //----- //
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<128x256xbf16>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    linalg.yield %in : bf16
  } -> tensor<128x256xbf16>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %3 = arith.mulf %in, %in_0 : bf16
    %4 = arith.addf %out, %3 : bf16
    linalg.yield %4 : bf16
  } -> tensor<128x256xbf16>
  return %2 : tensor<128x256xbf16>
}

// -----// IR Dump After ArithEmulateUnsupportedFloats (arith-emulate-unsupported-floats) //----- //
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<128x256xbf16>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    linalg.yield %in : bf16
  } -> tensor<128x256xbf16>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %3 = arith.extf %out fastmath<contract> : bf16 to f32
    %4 = arith.extf %in_0 fastmath<contract> : bf16 to f32
    %5 = arith.extf %in fastmath<contract> : bf16 to f32
    %6 = arith.mulf %5, %4 : f32
    %7 = arith.truncf %6 fastmath<contract> : f32 to bf16
    %8 = arith.extf %7 fastmath<contract> : bf16 to f32
    %9 = arith.addf %3, %8 : f32
    %10 = arith.truncf %9 fastmath<contract> : f32 to bf16
    linalg.yield %10 : bf16
  } -> tensor<128x256xbf16>
  return %2 : tensor<128x256xbf16>
}

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<128x256xbf16>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    linalg.yield %in : bf16
  } -> tensor<128x256xbf16>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %3 = arith.extf %out fastmath<contract> : bf16 to f32
    %4 = arith.extf %in_0 fastmath<contract> : bf16 to f32
    %5 = arith.extf %in fastmath<contract> : bf16 to f32
    %6 = arith.mulf %5, %4 : f32
    %7 = arith.addf %3, %6 : f32
    %8 = arith.truncf %7 fastmath<contract> : f32 to bf16
    linalg.yield %8 : bf16
  } -> tensor<128x256xbf16>
  return %2 : tensor<128x256xbf16>
}

// -----// IR Dump After ArithExpandOpsPass (arith-expand) //----- //
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<128x256xbf16>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    linalg.yield %in : bf16
  } -> tensor<128x256xbf16>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %3 = arith.extf %out fastmath<contract> : bf16 to f32
    %4 = arith.extf %in_0 fastmath<contract> : bf16 to f32
    %5 = arith.extf %in fastmath<contract> : bf16 to f32
    %6 = arith.mulf %5, %4 : f32
    %7 = arith.addf %3, %6 : f32
    %8 = arith.truncf %7 fastmath<contract> : f32 to bf16
    linalg.yield %8 : bf16
  } -> tensor<128x256xbf16>
  return %2 : tensor<128x256xbf16>
}

// -----// IR Dump After OneShotBufferize (one-shot-bufferize) //----- //
#map = affine_map<(d0, d1) -> ()>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @main_entry(%arg0: memref<128x512xbf16>, %arg1: memref<512x256xbf16>) -> memref<128x256xbf16> attributes {llvm.emit_c_interface} {
    %cst = arith.constant 0.000000e+00 : bf16
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x256xbf16>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%alloc : memref<128x256xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      linalg.yield %in : bf16
    }
    linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<128x512xbf16>, memref<512x256xbf16>) outs(%alloc : memref<128x256xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %0 = arith.extf %out fastmath<contract> : bf16 to f32
      %1 = arith.extf %in_0 fastmath<contract> : bf16 to f32
      %2 = arith.extf %in fastmath<contract> : bf16 to f32
      %3 = arith.mulf %2, %1 : f32
      %4 = arith.addf %0, %3 : f32
      %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
      linalg.yield %5 : bf16
    }
    return %alloc : memref<128x256xbf16>
  }
}


// -----// IR Dump After CSE (cse) //----- //
#map = affine_map<(d0, d1) -> ()>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @main_entry(%arg0: memref<128x512xbf16>, %arg1: memref<512x256xbf16>) -> memref<128x256xbf16> attributes {llvm.emit_c_interface} {
    %cst = arith.constant 0.000000e+00 : bf16
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x256xbf16>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%alloc : memref<128x256xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      linalg.yield %in : bf16
    }
    linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<128x512xbf16>, memref<512x256xbf16>) outs(%alloc : memref<128x256xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %0 = arith.extf %out fastmath<contract> : bf16 to f32
      %1 = arith.extf %in_0 fastmath<contract> : bf16 to f32
      %2 = arith.extf %in fastmath<contract> : bf16 to f32
      %3 = arith.mulf %2, %1 : f32
      %4 = arith.addf %0, %3 : f32
      %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
      linalg.yield %5 : bf16
    }
    return %alloc : memref<128x256xbf16>
  }
}


// -----// IR Dump After BufferResultsToOutParams (buffer-results-to-out-params) //----- //
...

// -----// IR Dump After BufferHoisting (buffer-hoisting) //----- //
...

// -----// IR Dump After BufferLoopHoisting (buffer-loop-hoisting) //----- //
...

// -----// IR Dump After BufferDeallocation (buffer-deallocation) //----- //
...

// -----// IR Dump After ConvertBufferizationToMemRef (convert-bufferization-to-memref) //----- //
...

// -----// IR Dump After ConvertSCFToOpenMPPass (convert-scf-to-openmp) //----- //
...

// -----// IR Dump After ExpandOps (memref-expand) //----- //
...

// -----// IR Dump After ExpandStridedMetadata (expand-strided-metadata) //----- //
...

// -----// IR Dump After FinalizeMemRefToLLVMConversionPass (finalize-memref-to-llvm) //----- //
...

// -----// IR Dump After SCFToControlFlow (convert-scf-to-cf) //----- //
...

// -----// IR Dump After CPURuntimeToLLVM (convert-cpuruntime-to-llvm) //----- //
...

// -----// IR Dump After ConvertOpenMPToLLVMPass (convert-openmp-to-llvm) //----- //
...

// -----// IR Dump After ConvertMathToLibm (convert-math-to-libm) //----- //
...

// -----// IR Dump After ConvertFuncToLLVMPass (convert-func-to-llvm) //----- //
...

// -----// IR Dump After ConvertControlFlowToLLVMPass (convert-cf-to-llvm) //----- //
...

// -----// IR Dump After CSE (cse) //----- //
...

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
...

// -----// IR Dump After ReconcileUnrealizedCasts (reconcile-unrealized-casts) //----- //
...

// -----// IR Dump After SymbolDCE (symbol-dce) //----- //
...


(The final module printed by gc-opt is identical to the all-in-one result above, so the pipeline preserves the end-to-end lowering while exposing every intermediate step.)
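
With a real pipeline, the standard MLIR printing flags can also target a single pass. Assuming gc-opt keeps forwarding the usual mlir-opt options (it already accepts --mlir-print-ir-after-all), something like the following should dump only the IR after the OneDNN-graph lowering:

./bin/gc-opt test.mlir --gc-cpu-pipeline --mlir-print-ir-after=convert-onednn-graph-to-linalg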

@kurapov-peter kurapov-peter merged commit 5c0e196 into main Jul 18, 2024
4 checks passed
dchigarev pushed a commit to dchigarev/graph-compiler that referenced this pull request Jul 29, 2024