diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h @@ -43,8 +43,9 @@ /// /// A value "has memory effects" iff it: /// (1.a) is an operand of an op with memory effects OR -/// (1.b) is a non-forwarded branch operand and a block where its op could -/// take the control has an op with memory effects. +/// (1.b) is a non-forwarded branch operand and its branch op could take the +/// control to a block that has an op with memory effects OR +/// (1.c) is a non-forwarded call operand. /// /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be /// computed in the absence of `A`. Thus, in this implementation, we say that @@ -83,6 +84,8 @@ void visitBranchOperand(OpOperand &operand) override; + void visitCallOperand(OpOperand &operand) override; + void setToExitState(Liveness *lattice) override; }; diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -363,9 +363,12 @@ Operation *op, ArrayRef operandLattices, ArrayRef resultLattices) = 0; - // Visit operands on branch instructions that are not forwarded + /// Visit operands on branch instructions that are not forwarded. virtual void visitBranchOperand(OpOperand &operand) = 0; + /// Visit operands on call instructions that are not forwarded to the callee. + virtual void visitCallOperand(OpOperand &operand) = 0; + /// Set the given lattice element(s) at control flow exit point(s).
virtual void setToExitState(AbstractSparseLattice *lattice) = 0; diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -54,8 +55,9 @@ /// /// A value "has memory effects" iff it: /// (1.a) is an operand of an op with memory effects OR -/// (1.b) is a non-forwarded branch operand and a block where its op could -/// take the control has an op with memory effects. +/// (1.b) is a non-forwarded branch operand and its branch op could take the +/// control to a block that has an op with memory effects OR +/// (1.c) is a non-forwarded call operand. /// /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be /// computed in the absence of `A`. Thus, in this implementation, we say that @@ -149,8 +151,6 @@ // Now that we have checked for memory-effecting ops in the blocks of concern, // we will simply visit the op with this non-forwarded operand to potentially // mark it "live" due to type (1.a/3) liveness. - if (operand.getOperandNumber() > 0) - return; SmallVector operandLiveness; operandLiveness.push_back(getLatticeElement(operand.get())); SmallVector resultsLiveness; @@ -171,6 +171,22 @@ visitOperation(parentOp, operandLiveness, parentResultsLiveness); } +void LivenessAnalysis::visitCallOperand(OpOperand &operand) { + // We know (at the moment) and assume (for the future) that `operand` is a + // non-forwarded call operand of an op implementing `CallOpInterface`. + assert(isa(operand.getOwner()) && + "expected the op to implement `CallOpInterface`"); + + // The lattices of the non-forwarded call operands don't get updated like the + // forwarded call operands or the non-call operands. Thus they need to be + // handled separately. This is where we handle them.
+ + // This marks values of type (1.c) liveness as "live". A non-forwarded + // call operand is live. + Liveness *operandLiveness = getLatticeElement(operand.get()); + propagateIfChanged(operandLiveness, operandLiveness->markLive()); +} + void LivenessAnalysis::setToExitState(Liveness *lattice) { // This marks values of type (2) liveness as "live". lattice->markLive(); diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -412,19 +412,34 @@ return; } - // For function calls, connect the arguments of the entry blocks - // to the operands of the call op. + // For function calls, connect the arguments of the entry blocks to the + // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { Operation *callableOp = call.resolveCallable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { + // Not all operands of a call op forward to arguments. Such operands are + // stored in `unaccounted`. + BitVector unaccounted(op->getNumOperands(), true); + + OperandRange argOperands = call.getArgOperands(); + MutableArrayRef argOpOperands = + operandsToOpOperands(argOperands); Region *region = callable.getCallableRegion(); if (region && !region->empty()) { Block &block = region->front(); - for (auto [blockArg, operand] : - llvm::zip(block.getArguments(), operandLattices)) { - meet(operand, *getLatticeElementFor(op, blockArg)); + for (auto [blockArg, argOpOperand] : + llvm::zip(block.getArguments(), argOpOperands)) { + meet(getLatticeElement(argOpOperand.get()), + *getLatticeElementFor(op, blockArg)); + unaccounted.reset(argOpOperand.getOperandNumber()); } } + // Handle the operands of the call op that aren't forwarded to any + // arguments. If the callee has no body, this is every operand.
+ for (unsigned index : unaccounted.set_bits()) { + OpOperand &opOperand = op->getOpOperand(index); + visitCallOperand(opOperand); + } return; } } diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -59,11 +59,27 @@ // ----- +func.func private @private(%arg0 : i32, %arg1 : i32) { + func.return +} + +// Positive test: Type (1.c) "is a non-forwarded call operand" +// CHECK-LABEL: test_tag: call: +// CHECK-NEXT: operand #0: not live +// CHECK-NEXT: operand #1: not live +// CHECK-NEXT: operand #2: live +func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32) { + test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> () + return +} + +// ----- + // Positive test: Type (2) "is returned by a public function" // zero is live because it is returned by a public function.
// CHECK-LABEL: test_tag: zero: // CHECK-NEXT: result #0: live -func.func @test_4_type_2() -> (f32){ +func.func @test_5_type_2() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 return %0 : f32 } @@ -90,7 +106,7 @@ // CHECK-NEXT: operand #3: live // CHECK-LABEL: test_tag: add: // CHECK-NEXT: operand #0: live -func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref, %arg1: i1) -> (i32) { +func.func @test_6_RegionBranchTerminatorOpInterface_type_3(%arg0: memref, %arg1: i1) -> (i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 @@ -135,7 +151,7 @@ // CHECK-NEXT: result #0: live // CHECK-LABEL: test_tag: y: // CHECK-NEXT: result #0: not live -func.func @test_6_type_3(%arg0: memref) { +func.func @test_7_type_3(%arg0: memref) { %c0 = arith.constant {tag = "zero"} 0 : index %c10 = arith.constant {tag = "ten"} 10 : index %c1 = arith.constant {tag = "one"} 1 : index @@ -190,7 +206,7 @@ // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live // CHECK-NEXT: result #0: live -func.func @test_7_type_3(%arg: i32) -> (i32) { +func.func @test_8_type_3(%arg: i32) -> (i32) { %0 = func.call @private1(%arg) : (i32) -> i32 %final = arith.muli %0, %arg {tag = "final"} : i32 return %final : i32 @@ -205,7 +221,7 @@ // CHECK-NEXT: result #0: not live // CHECK-LABEL: test_tag: one: // CHECK-NEXT: result #0: live -func.func @test_8_negative() -> (f32){ +func.func @test_9_negative() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 %1 = arith.constant {tag = "one"} 1.0 : f32 return %1 : f32 @@ -230,7 +246,7 @@ %1 = arith.addi %0, %0 {tag = "one"} : i32 return %0, %1 : i32, i32 } -func.func @test_9_negative() -> (i32) { +func.func @test_10_negative() -> (i32) { %0:2 = func.call @private_1() : () -> (i32, i32) return %0#0 : i32 } diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir --- a/mlir/test/Analysis/DataFlow/test-written-to.mlir +++
b/mlir/test/Analysis/DataFlow/test-written-to.mlir @@ -286,4 +286,21 @@ llvm.func @func(%lb : i64) -> () { llvm.call @decl(%lb) : (i64) -> () llvm.return -} +} + +// ----- + +func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 { + func.return %arg0 : i32 +} + +// CHECK-LABEL: test_tag: a +// CHECK-NEXT: operand #0: [b] +// CHECK-NEXT: operand #1: [] +// CHECK-NEXT: operand #2: [callarg2] +// CHECK-NEXT: result #0: [b] +func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref) { + %0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32) + memref.store %0, %m0[] {tag_name = "b"} : memref + return +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp @@ -57,6 +57,8 @@ void visitBranchOperand(OpOperand &operand) override; + void visitCallOperand(OpOperand &operand) override; + void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); } }; @@ -87,6 +89,16 @@ propagateIfChanged(lattice, lattice->addWrites(newWrites)); } +void WrittenToAnalysis::visitCallOperand(OpOperand &operand) { + // Mark call operands as "callarg%d", with %d the operand number.
+ WrittenTo *lattice = getLatticeElement(operand.get()); + SetVector newWrites; + newWrites.insert( + StringAttr::get(operand.getOwner()->getContext(), + "callarg" + Twine(operand.getOperandNumber()))); + propagateIfChanged(lattice, lattice->addWrites(newWrites)); +} + } // end anonymous namespace namespace { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1296,6 +1296,22 @@ return getCalleeOperandsMutable(); } +CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { + return getCallee(); +} + +void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { + setCalleeAttr(callee.get()); +} + +Operation::operand_range TestCallOnDeviceOp::getArgOperands() { + return getForwardedOperands(); +} + +MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { + return getForwardedOperandsMutable(); +} + void TestStoreWithARegion::getSuccessorRegions( std::optional index, SmallVectorImpl &regions) { if (!index) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2834,6 +2834,21 @@ "`:` functional-type(operands, results)"; } +def TestCallOnDeviceOp : TEST_Op<"call_on_device", + [DeclareOpInterfaceMethods]> { + let arguments = (ins + SymbolRefAttr:$callee, + Variadic:$forwarded_operands, + AnyType:$non_forwarded_device_operand + ); + let results = (outs + Variadic:$results + ); + let assemblyFormat = + "$callee `(` $forwarded_operands `)` `,` $non_forwarded_device_operand " + "attr-dict `:` functional-type(operands, results)"; +}