diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -65,16 +65,18 @@ /// operation into the kernel. An operation can be sunk if doing so does not /// introduce new kernel arguments. Whether a value is already available in the /// kernel (and hence does not introduce new arguments) is checked by -/// querying `availableValues`. +/// querying `existingDependencies` and `availableValues`. /// If an operand is not yet available, we recursively check whether it can be /// made available by siking its defining op. /// Operations that are indentified for sinking are added to `beneficiaryOps` in -/// the order the should appear in the kernel. Furthermore, `availableValues` is -/// updated with results that will be available after sinking the identified +/// the order they should appear in the kernel. Furthermore, `availableValues` +/// is updated with results that will be available after sinking the identified /// ops. -static bool extractBeneficiaryOps(Operation *op, - llvm::SetVector &beneficiaryOps, - llvm::SetVector &availableValues) { +static bool +extractBeneficiaryOps(Operation *op, + const llvm::SetVector &existingDependencies, + llvm::SetVector &beneficiaryOps, + llvm::SmallPtrSetImpl &availableValues) { if (beneficiaryOps.count(op)) return true; @@ -85,10 +87,13 @@ // It is already visisble in the kernel, keep going. if (availableValues.count(operand)) continue; - // Else check whether it can be made available via sinking. + // Else check whether it can be made available via sinking or already is a + // dependency. 
Operation *definingOp = operand.getDefiningOp(); - if (!definingOp || - !extractBeneficiaryOps(definingOp, beneficiaryOps, availableValues)) + if ((!definingOp || + !extractBeneficiaryOps(definingOp, existingDependencies, + beneficiaryOps, availableValues)) && + !existingDependencies.count(operand)) return false; } // We will sink the operation, mark its results as now available. @@ -106,13 +111,13 @@ llvm::SetVector sinkCandidates; getUsedValuesDefinedAbove(launchOpBody, sinkCandidates); - SmallVector worklist(sinkCandidates.begin(), sinkCandidates.end()); llvm::SetVector toBeSunk; - for (Value operand : worklist) { + llvm::SmallPtrSet availableValues; + for (Value operand : sinkCandidates) { Operation *operandOp = operand.getDefiningOp(); if (!operandOp) continue; - extractBeneficiaryOps(operandOp, toBeSunk, sinkCandidates); + extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues); } // Insert operations so that the defs get cloned before uses. diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -165,14 +165,15 @@ // ----- +// CHECK-LABEL: @multiple_uses func @multiple_uses(%arg0 : memref) { %c1 = constant 1 : index %c2 = constant 2 : index // CHECK: gpu.func {{.*}} { - // CHECK: %[[C2:.*]] = constant 2 : index - // CHECK: "use1"(%[[C2]], %[[C2]]) - // CHECK: "use2"(%[[C2]]) - // CHECK: gpu.return + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: "use1"(%[[C2]], %[[C2]]) + // CHECK: "use2"(%[[C2]]) + // CHECK: gpu.return // CHECK: } gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) @@ -187,6 +188,33 @@ // ----- +// CHECK-LABEL: @multiple_uses2 +func @multiple_uses2(%arg0 : memref<*xf32>) { + %c1 = constant 1 : index + %c2 = constant 2 : index + %d = dim %arg0, %c2 : memref<*xf32> + // CHECK: gpu.func {{.*}} { + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: %[[D:.*]] = dim 
%[[ARG:.*]], %[[C2]] + // CHECK: "use1"(%[[D]]) + // CHECK: "use2"(%[[C2]], %[[C2]]) + // CHECK: "use3"(%[[ARG]]) + // CHECK: gpu.return + // CHECK: } + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, + %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, + %block_z = %c1) { + "use1"(%d) : (index) -> () + "use2"(%c2, %c2) : (index, index) -> () + "use3"(%arg0) : (memref<*xf32>) -> () + gpu.terminator + } + return +} + +// ----- + llvm.mlir.global internal @global(42 : i64) : !llvm.i64 //CHECK-LABEL: @function_call