diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -19,6 +19,7 @@ namespace mlir { class Operation; +class Value; /// Type of the condition to limit the propagation of transitive use-defs. /// This can be used in particular to limit the propagation to a given Scope or @@ -72,6 +73,13 @@ TransitiveFilter filter = /* pass-through*/ [](Operation *) { return true; }); +/// Value-rooted version of `getForwardSlice`. Return the union of all forward +/// slices for the uses of the value `root`. +void getForwardSlice( + Value root, llvm::SetVector *forwardSlice, + TransitiveFilter filter = /* pass-through*/ + [](Operation *) { return true; }); + /// Fills `backwardSlice` with the computed backward slice (i.e. /// all the transitive defs of op), **without** including that operation. /// @@ -111,6 +119,13 @@ TransitiveFilter filter = /* pass-through*/ [](Operation *) { return true; }); +/// Value-rooted version of `getBackwardSlice`. Return the union of all backward +/// slices for the op defining or owning the value `root`. +void getBackwardSlice( + Value root, llvm::SetVector *backwardSlice, + TransitiveFilter filter = /* pass-through*/ + [](Operation *) { return true; }); + /// Iteratively computes backward slices and forward slices until /// a fixed point is reached. Returns an `llvm::SetVector` which /// **includes** the original operation. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -30,36 +30,24 @@ static void getForwardSliceImpl(Operation *op, SetVector *forwardSlice, TransitiveFilter filter) { - if (!op) { + if (!op) return; - } // Evaluate whether we should keep this use. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. 
- if (!filter(op)) { + if (!filter(op)) return; - } - if (auto forOp = dyn_cast(op)) { - for (Operation *userOp : forOp.getInductionVar().getUsers()) + for (Region &region : op->getRegions()) + for (Block &block : region) + for (Operation &op : block) + if (forwardSlice->count(&op) == 0) + getForwardSliceImpl(&op, forwardSlice, filter); + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) if (forwardSlice->count(userOp) == 0) getForwardSliceImpl(userOp, forwardSlice, filter); - } else if (auto forOp = dyn_cast(op)) { - for (Operation *userOp : forOp.getInductionVar().getUsers()) - if (forwardSlice->count(userOp) == 0) - getForwardSliceImpl(userOp, forwardSlice, filter); - for (Value result : forOp.getResults()) - for (Operation *userOp : result.getUsers()) - if (forwardSlice->count(userOp) == 0) - getForwardSliceImpl(userOp, forwardSlice, filter); - } else { - assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); - for (Value result : op->getResults()) { - for (Operation *userOp : result.getUsers()) - if (forwardSlice->count(userOp) == 0) - getForwardSliceImpl(userOp, forwardSlice, filter); - } } forwardSlice->insert(op); @@ -79,45 +67,47 @@ forwardSlice->insert(v.rbegin(), v.rend()); } +void mlir::getForwardSlice(Value root, SetVector *forwardSlice, + TransitiveFilter filter) { + for (Operation *user : root.getUsers()) + getForwardSliceImpl(user, forwardSlice, filter); + + // Reverse to get back the actual topological order. + // std::reverse does not work out of the box on SetVector and I want an + // in-place swap based thing (the real std::reverse, not the LLVM adapter). 
+ std::vector v(forwardSlice->takeVector()); + forwardSlice->insert(v.rbegin(), v.rend()); +} + static void getBackwardSliceImpl(Operation *op, SetVector *backwardSlice, TransitiveFilter filter) { - if (!op) + if (!op || isa(op)) return; - assert((op->getNumRegions() == 0 || - isa( - op)) && - "unexpected generic op with regions"); - // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the - // transitive forwardSlice in the current scope. - if (!filter(op)) { + // transitive backwardSlice in the current scope. + if (!filter(op)) return; - } for (auto en : llvm::enumerate(op->getOperands())) { auto operand = en.value(); - if (auto blockArg = operand.dyn_cast()) { - if (auto affIv = getForInductionVarOwner(operand)) { - auto *affOp = affIv.getOperation(); - if (backwardSlice->count(affOp) == 0) - getBackwardSliceImpl(affOp, backwardSlice, filter); - } else if (auto loopIv = scf::getForInductionVarOwner(operand)) { - auto *loopOp = loopIv.getOperation(); - if (backwardSlice->count(loopOp) == 0) - getBackwardSliceImpl(loopOp, backwardSlice, filter); - } else if (blockArg.getOwner() != - &op->getParentOfType().getBody().front()) { - op->emitError("unsupported CF for operand ") << en.index(); - llvm_unreachable("Unsupported control flow"); - } - continue; - } - auto *op = operand.getDefiningOp(); - if (backwardSlice->count(op) == 0) { - getBackwardSliceImpl(op, backwardSlice, filter); + if (auto *definingOp = operand.getDefiningOp()) { + if (backwardSlice->count(definingOp) == 0) + getBackwardSliceImpl(definingOp, backwardSlice, filter); + } else if (auto blockArg = operand.dyn_cast()) { + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + // TODO: determine whether we want to recurse backward into the other + // blocks of parentOp, which are not technically backward unless they flow + // into us. For now, assert parentOp has one region with a single block. 
+ assert(parentOp->getNumRegions() == 1 && + parentOp->getRegion(0).getBlocks().size() == 1); + if (backwardSlice->count(parentOp) == 0) + getBackwardSliceImpl(parentOp, backwardSlice, filter); + } else { + llvm_unreachable("No definingOp and not a block argument."); } } @@ -134,6 +124,16 @@ backwardSlice->remove(op); } +void mlir::getBackwardSlice(Value root, SetVector *backwardSlice, + TransitiveFilter filter) { + if (Operation *definingOp = root.getDefiningOp()) { + getBackwardSlice(definingOp, backwardSlice, filter); + return; + } + Operation *bbAargOwner = root.cast().getOwner()->getParentOp(); + getBackwardSlice(bbAargOwner, backwardSlice, filter); +} + SetVector mlir::getSlice(Operation *op, TransitiveFilter backwardFilter, TransitiveFilter forwardFilter) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -243,7 +243,7 @@ << "\n"); llvm::SetVector forwardSlice; - getForwardSlice(transferRead, &forwardSlice); + getForwardSlice(transferRead.getOperation(), &forwardSlice); // Look for the last TransferWriteOp in the forwardSlice of // `transferRead` that operates on the same memref. @@ -381,9 +381,10 @@ // Get the backwards slice from `padTensorOp` that is dominated by the // outermost enclosing loop. DominanceInfo domInfo(outermostEnclosingForOp); - getBackwardSlice(padTensorOp, &backwardSlice, [&](Operation *op) { - return domInfo.dominates(outermostEnclosingForOp, op); - }); + getBackwardSlice(padTensorOp.getOperation(), &backwardSlice, + [&](Operation *op) { + return domInfo.dominates(outermostEnclosingForOp, op); + }); // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp. 
if (llvm::any_of(backwardSlice, [](Operation *op) { diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1830,9 +1830,9 @@ // Return failure when any op fails to hoist. static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) { SetVector forwardSlice; - getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) { - return op != inner.getOperation(); - }); + getForwardSlice( + outer.getInductionVar(), &forwardSlice, + [&inner](Operation *op) { return op != inner.getOperation(); }); LogicalResult status = success(); SmallVector toHoist; for (auto &op : outer.getBody()->without_terminator()) { @@ -1844,8 +1844,8 @@ status = failure(); continue; } - // Skip scf::ForOp, these are not considered a failure. - if (op.getNumRegions() > 0) + // Skip intermediate scf::ForOp, these are not considered a failure. + if (isa(op)) continue; // Skip other ops with regions. if (op.getNumRegions() > 0) {