diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -23,6 +23,9 @@ } // namespace llvm namespace mlir { +/// Creates a pass that sinks likely index computations into gpu.launch bodies. +std::unique_ptr createGpuLaunchSinkOpsPass(); + /// Replaces `gpu.launch` with `gpu.launch_func` by moving the region into /// a separate kernel function. std::unique_ptr> diff --git a/mlir/include/mlir/Dialect/GPU/Passes.td b/mlir/include/mlir/Dialect/GPU/Passes.td --- a/mlir/include/mlir/Dialect/GPU/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Passes.td @@ -11,6 +11,12 @@ include "mlir/Pass/PassBase.td" +def GpuLaunchSinkOps : Pass<"gpu-launch-sink-ops"> { + let summary = "Sink ops into gpu.launch body"; + let constructor = "mlir::createGpuLaunchSinkOpsPass()"; + let dependentDialects = ["gpu::GPUDialect"]; +} + def GpuKernelOutlining : Pass<"gpu-kernel-outlining", "ModuleOp"> { let summary = "Outline gpu.launch bodies to kernel functions"; let constructor = "mlir::createGpuKernelOutliningPass()"; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -59,7 +59,7 @@ /// Identifies operations that are beneficial to sink into kernels. These /// operations may not have side-effects, as otherwise sinking (and hence /// duplicating them) is not legal. -static bool isLikelyAnIndexComputatio(Operation *op) { +static bool isLikelyAnIndexComputation(Operation *op) { return isa(op); } @@ -232,6 +232,24 @@ } namespace { +/// Pass that duplicates ops that are likely index computations into the body +/// of each gpu.launch op that uses them (see isLikelyAnIndexComputation).
+class GpuLaunchSinkOpsPass : public GpuLaunchSinkOpsBase { +public: + void runOnOperation() override { + Operation *op = getOperation(); + if (op->walk([](gpu::LaunchOp launch) { + // Duplicate sinkable ops into the launch body; abort the walk on failure. + if (failed(sinkOperationsIntoLaunchOp(launch, + isLikelyAnIndexComputation))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }).wasInterrupted()) + signalPassFailure(); + } +}; + /// Pass that moves the kernel of each LaunchOp into its separate nested module. /// /// This pass moves the kernel code of each LaunchOp into a function created @@ -280,9 +298,6 @@ std::string kernelFnName = Twine(op->getParentOfType().getName(), "_kernel").str(); - // Pull in instructions that can be sunk - if (failed(sinkOperationsIntoLaunchOp(op, isLikelyAnIndexComputatio))) - return WalkResult::interrupt(); gpu::GPUFuncOp outlinedFunc = outlineKernelFuncImpl(op, kernelFnName, operands); @@ -360,6 +375,10 @@ } // namespace +std::unique_ptr mlir::createGpuLaunchSinkOpsPass() { + return std::make_unique(); +} + std::unique_ptr> mlir::createGpuKernelOutliningPass(StringRef dataLayoutStr) { return std::make_unique(dataLayoutStr); diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -allow-unregistered-dialect -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect -gpu-kernel-outlining=data-layout-str='#dlti.dl_spec<#dlti.dl_entry>' -split-input-file %s | FileCheck --check-prefix CHECK-DL %s +// RUN: mlir-opt -allow-unregistered-dialect -gpu-launch-sink-ops -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -gpu-launch-sink-ops -gpu-kernel-outlining=data-layout-str='#dlti.dl_spec<#dlti.dl_entry>' -split-input-file %s | FileCheck --check-prefix CHECK-DL %s
// CHECK: module attributes {gpu.container_module} diff --git a/mlir/test/Dialect/GPU/sink-ops.mlir b/mlir/test/Dialect/GPU/sink-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/GPU/sink-ops.mlir @@ -0,0 +1,100 @@ +// RUN: mlir-opt -allow-unregistered-dialect -gpu-launch-sink-ops -split-input-file -verify-diagnostics %s | FileCheck %s + +// Sinkable ops used in the launch body are duplicated into it; the launch-size constant (8) is not. +// CHECK-LABEL: @extra_constants +// CHECK-SAME: %[[ARG0:.*]]: memref +func @extra_constants(%arg0: memref) { + %cst = arith.constant 8 : index + %cst2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %cst3 = memref.dim %arg0, %c0 : memref + // CHECK: gpu.launch blocks + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, + %grid_z = %cst) + threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, + %block_z = %cst) { + // CHECK-NOT: arith.constant 8 + // CHECK: %[[CST2:.*]] = arith.constant 2 + // CHECK-NEXT: %[[CST0:.*]] = arith.constant 0 + // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[CST0]] + // CHECK-NEXT: "use"(%[[CST2]], %[[ARG0]], %[[DIM]]) : (index, memref, index) -> () + // CHECK-NEXT: gpu.terminator + "use"(%cst2, %arg0, %cst3) : (index, memref, index) -> () + gpu.terminator + } + return +} + +// ----- +// An op that is not a likely index computation ("secret_constant") is not sunk; its result stays defined outside the launch. +// CHECK-LABEL: @extra_constants_not_inlined +// CHECK-SAME: %[[ARG0:.*]]: memref +func @extra_constants_not_inlined(%arg0: memref) { + %cst = arith.constant 8 : index + %cst2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + // CHECK: %[[CST_X:.*]] = "secret_constant"() + %cst3 = "secret_constant"() : () -> index + // CHECK: gpu.launch blocks + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, + %grid_z = %cst) + threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, + %block_z = %cst) { + // CHECK-NOT: arith.constant 8 + // CHECK-NOT: "secret_constant"() + // CHECK: %[[CST2:.*]] = arith.constant 2 + // CHECK-NEXT: "use"(%[[CST2]], %[[ARG0]], %[[CST_X]]) : (index, memref, index) -> () + // CHECK-NEXT: gpu.terminator +
"use"(%cst2, %arg0, %cst3) : (index, memref, index) -> () + gpu.terminator + } + return +} + +// ----- +// Every use of a sunk value in the launch body refers to the same duplicated op. +// CHECK-LABEL: @multiple_uses +func @multiple_uses(%arg0 : memref) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + // CHECK: gpu.launch blocks + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, + %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, + %block_z = %c1) { + // CHECK: %[[C2:.*]] = arith.constant 2 + // CHECK-NEXT: "use1"(%[[C2]], %[[C2]]) + // CHECK-NEXT: "use2"(%[[C2]]) + // CHECK-NEXT: gpu.terminator + "use1"(%c2, %c2) : (index, index) -> () + "use2"(%c2) : (index) -> () + gpu.terminator + } + return +} + +// ----- +// Operands of a sunk op are sunk too (memref.dim and its index constant); the function argument %arg0 is used directly. +// CHECK-LABEL: @multiple_uses2 +func @multiple_uses2(%arg0 : memref<*xf32>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %d = memref.dim %arg0, %c2 : memref<*xf32> + // CHECK: gpu.launch blocks + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, + %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, + %block_z = %c1) { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D:.*]] = memref.dim %[[ARG:.*]], %[[C2]] + // CHECK: "use1"(%[[D]]) + // CHECK: "use2"(%[[C2]], %[[C2]]) + // CHECK: "use3"(%[[ARG]]) + // CHECK: gpu.terminator + "use1"(%d) : (index) -> () + "use2"(%c2, %c2) : (index, index) -> () + "use3"(%arg0) : (memref<*xf32>) -> () + gpu.terminator + } + return +}