diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -55,7 +56,7 @@ /// A value "has memory effects" iff it: /// (1.a) is an operand of an op with memory effects OR /// (1.b) is a non-forwarded branch operand and a block where its op could -/// take the control has an op with memory effects. +/// take the control is a callable or has an op with memory effects. /// /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be /// computed in the absence of `A`. Thus, in this implementation, we say that @@ -90,23 +91,33 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // We know (at the moment) and assume (for the future) that `operand` is a // non-forwarded branch operand of a `RegionBranchOpInterface`, - // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op. + // `BranchOpInterface`, `RegionBranchTerminatorOpInterface`, return-like op, + // or `CallOpInterface`. Operation *op = operand.getOwner(); assert((isa(op) || isa(op) || isa(op) || - op->hasTrait()) && + op->hasTrait() || isa(op)) && "expected the op to be `RegionBranchOpInterface`, " - "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or " - "return-like"); + "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, " + "return-like, or `CallOpInterface`"); // The lattices of the non-forwarded branch operands don't get updated like // the forwarded branch operands or the non-branch operands. Thus they need // to be handled separately. This is where we handle them. - // This marks values of type (1.b) liveness as "live". A non-forwarded - // branch operand will be live if a block where its op could take the control - // has an op with memory effects. - // Populating such blocks in `blocks`.
+ // The code below marks values with type (1.b) liveness as "live". + + // Since the block to which a `CallOpInterface` with a non-forwarded branch + // operand takes control is callable, we simply mark these operands as live. + if (isa(op)) { + Liveness *operandLiveness = getLatticeElement(operand.get()); + propagateIfChanged(operandLiveness, operandLiveness->markLive()); + return; + } + + // A non-forwarded branch operand will be live if a block where its op could + // take the control has an op with memory effects. We collect such blocks in + // `blocks`. SmallVector blocks; if (isa(op)) { // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an @@ -151,8 +162,6 @@ // Now that we have checked for memory-effecting ops in the blocks of concern, // we will simply visit the op with this non-forwarded operand to potentially // mark it "live" due to type (1.a/3) liveness. - if (operand.getOperandNumber() > 0) - return; SmallVector operandLiveness; operandLiveness.push_back(getLatticeElement(operand.get())); SmallVector resultsLiveness; diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -412,19 +412,34 @@ return; } - // For function calls, connect the arguments of the entry blocks - // to the operands of the call op. + // For function calls, connect the arguments of the entry blocks to the + // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { Operation *callableOp = call.resolveCallable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { + // Not all operands of a call op forward to arguments. Such operands are + // stored in `unaccounted`.
+ BitVector unaccounted(op->getNumOperands(), true); + + OperandRange argOperands = call.getArgOperands(); + MutableArrayRef argOpOperands = + operandsToOpOperands(argOperands); Region *region = callable.getCallableRegion(); if (region && !region->empty()) { Block &block = region->front(); - for (auto [blockArg, operand] : - llvm::zip(block.getArguments(), operandLattices)) { - meet(operand, *getLatticeElementFor(op, blockArg)); + for (auto [blockArg, argOpOperand] : + llvm::zip(block.getArguments(), argOpOperands)) { + meet(getLatticeElement(argOpOperand.get()), + *getLatticeElementFor(op, blockArg)); + unaccounted.reset(argOpOperand.getOperandNumber()); } } + // Handle the operands of the call op that aren't forwarded to any + // arguments (all of them when the callee has no body). + for (unsigned index : unaccounted.set_bits()) { + OpOperand &opOperand = op->getOpOperand(index); + visitBranchOperand(opOperand); + } return; } } diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -13,7 +13,8 @@ // ----- // Positive test: Type (1.b) "is a non-forwarded branch operand and a block -// where its op could take the control has an op with memory effects" +// where its op could take the control is a callable or has an op with memory +// effects" // %arg2 is live because it can make the control go into a block with a memory // effecting op. // CHECK-LABEL: test_tag: br: @@ -36,7 +37,8 @@ // ----- // Positive test: Type (1.b) "is a non-forwarded branch operand and a block -// where its op could take the control has an op with memory effects" +// where its op could take the control is a callable or has an op with memory +// effects" // %arg0 is live because it can make the control go into a block with a memory // effecting op.
// CHECK-LABEL: test_tag: flag: @@ -59,11 +61,29 @@ // ----- +func.func private @private(%arg0 : i32, %arg1 : i32) { + func.return +} + +// Positive test: Type (1.b) "is a non-forwarded branch operand and a block +// where its op could take the control is a callable or has an op with memory +// effects" +// CHECK-LABEL: test_tag: call: +// CHECK-NEXT: operand #0: not live +// CHECK-NEXT: operand #1: not live +// CHECK-NEXT: operand #2: live +func.func @test_4_CallOpInterface_type_1.b(%arg0: i32, %arg1: i32, %device: i32) { + test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> () + return +} + +// ----- + // Positive test: Type (2) "is returned by a public function" // zero is live because it is returned by a public function. // CHECK-LABEL: test_tag: zero: // CHECK-NEXT: result #0: live -func.func @test_4_type_2() -> (f32){ +func.func @test_5_type_2() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 return %0 : f32 } @@ -90,7 +110,7 @@ // CHECK-NEXT: operand #3: live // CHECK-LABEL: test_tag: add: // CHECK-NEXT: operand #0: live -func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref, %arg1: i1) -> (i32) { +func.func @test_6_RegionBranchTerminatorOpInterface_type_3(%arg0: memref, %arg1: i1) -> (i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 @@ -135,7 +155,7 @@ // CHECK-NEXT: result #0: live // CHECK-LABEL: test_tag: y: // CHECK-NEXT: result #0: not live -func.func @test_6_type_3(%arg0: memref) { +func.func @test_7_type_3(%arg0: memref) { %c0 = arith.constant {tag = "zero"} 0 : index %c10 = arith.constant {tag = "ten"} 10 : index %c1 = arith.constant {tag = "one"} 1 : index @@ -190,7 +210,7 @@ // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live // CHECK-NEXT: result #0: live -func.func @test_7_type_3(%arg: i32) -> (i32) { +func.func @test_8_type_3(%arg: i32) -> (i32) { %0 = func.call @private1(%arg) : (i32) -> i32 %final = 
arith.muli %0, %arg {tag = "final"} : i32 return %final : i32 @@ -205,7 +225,7 @@ // CHECK-NEXT: result #0: not live // CHECK-LABEL: test_tag: one: // CHECK-NEXT: result #0: live -func.func @test_8_negative() -> (f32){ +func.func @test_9_negative() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 %1 = arith.constant {tag = "one"} 1.0 : f32 return %1 : f32 @@ -230,7 +250,7 @@ %1 = arith.addi %0, %0 {tag = "one"} : i32 return %0, %1 : i32, i32 } -func.func @test_9_negative() -> (i32) { +func.func @test_10_negative() -> (i32) { %0:2 = func.call @private_1() : () -> (i32, i32) return %0#0 : i32 } diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir --- a/mlir/test/Analysis/DataFlow/test-written-to.mlir +++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir @@ -286,4 +286,21 @@ llvm.func @func(%lb : i64) -> () { llvm.call @decl(%lb) : (i64) -> () llvm.return -} +} + +// ----- + +func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 { + func.return %arg0 : i32 +} + +// CHECK-LABEL: test_tag: a +// CHECK: operand #0: [b] +// CHECK: operand #1: [] +// CHECK: operand #2: [brancharg2] +// CHECK: result #0: [b] +func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref) { + %0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32) + memref.store %0, %m0[] {tag_name = "b"} : memref + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1267,6 +1267,22 @@ return getCalleeOperandsMutable(); } +CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { + return getCallee(); +} + +void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { + setCalleeAttr(callee.get()); +} + +Operation::operand_range TestCallOnDeviceOp::getArgOperands() { + 
return getForwardedOperands(); +} + +MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { + return getForwardedOperandsMutable(); +} + void TestStoreWithARegion::getSuccessorRegions( std::optional index, ArrayRef operands, SmallVectorImpl &regions) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2811,6 +2811,21 @@ "`:` functional-type(operands, results)"; } +def TestCallOnDeviceOp : TEST_Op<"call_on_device", + [DeclareOpInterfaceMethods]> { + let arguments = (ins + SymbolRefAttr:$callee, + Variadic:$forwarded_operands, + AnyType:$non_forwarded_device_operand + ); + let results = (outs + Variadic:$results + ); + let assemblyFormat = + "$callee `(` $forwarded_operands `)` `,` $non_forwarded_device_operand " + "attr-dict `:` functional-type(operands, results)"; +} + def TestStoreWithARegion : TEST_Op<"store_with_a_region", [DeclareOpInterfaceMethods, SingleBlock]> {