diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -21,11 +21,27 @@ class Operation; class Value; -/// Type of the condition to limit the propagation of transitive use-defs. -/// This can be used in particular to limit the propagation to a given Scope or -/// to avoid passing through certain types of operation in a configurable -/// manner. -using TransitiveFilter = llvm::function_ref<bool(Operation *)>; +struct SliceOptions { + /// Type of the condition to limit the propagation of transitive use-defs. + /// This can be used in particular to limit the propagation to a given Scope + /// or to avoid passing through certain types of operation in a configurable + /// manner. A null filter (the default) is pass-through. + using TransitiveFilter = std::function<bool(Operation *)>; + TransitiveFilter filter = nullptr; + + /// Include the top level op in the slice. + bool inclusive = false; +}; + +struct BackwardSliceOptions : public SliceOptions { + /// When omitBlockArguments is true, the backward slice computation omits + /// traversing any block arguments. When omitBlockArguments is false, the + /// backward slice computation traverses block arguments and asserts that the + /// parent op has a single region with a single block. + bool omitBlockArguments = false; +}; + +using ForwardSliceOptions = SliceOptions; /// Fills `forwardSlice` with the computed forward slice (i.e. all /// the transitive uses of op), **without** including that operation. @@ -69,14 +85,12 @@ /// {4, 3, 6, 2, 1, 5, 8, 7, 9} /// void getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/, - bool inclusive = false); + ForwardSliceOptions options = {}); /// Value-rooted version of `getForwardSlice`. Return the union of all forward /// slices for the uses of the value `root`.
void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/, - bool inclusive = false); + ForwardSliceOptions options = {}); /// Fills `backwardSlice` with the computed backward slice (i.e. /// all the transitive defs of op), **without** including that operation. @@ -113,14 +127,12 @@ /// {1, 2, 5, 3, 4, 6} /// void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/, - bool inclusive = false); + BackwardSliceOptions options = {}); /// Value-rooted version of `getBackwardSlice`. Return the union of all backward /// slices for the op defining or owning the value `root`. void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/, - bool inclusive = false); + BackwardSliceOptions options = {}); /// Iteratively computes backward slices and forward slices until /// a fixed point is reached. Returns an `SetVector<Operation *>` which @@ -199,11 +211,9 @@ /// and keep things ordered but this is still hand-wavy and not worth the /// trouble for now: punt to a simple worklist-based solution. /// -SetVector<Operation *> -getSlice(Operation *op, - TransitiveFilter backwardFilter = nullptr /* pass-through*/, - TransitiveFilter forwardFilter = nullptr /* pass-through*/, - bool inclusive = false); +SetVector<Operation *> getSlice(Operation *op, + BackwardSliceOptions backwardSliceOptions = {}, + ForwardSliceOptions forwardSliceOptions = {}); /// Multi-root DAG topological sort. /// Performs a topological sort of the Operation in the `toSort` SetVector.
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -24,9 +24,9 @@ using namespace mlir; -static void getForwardSliceImpl(Operation *op, - SetVector<Operation *> *forwardSlice, - TransitiveFilter filter) { +static void +getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice, + const SliceOptions::TransitiveFilter &filter = nullptr) { if (!op) return; @@ -51,9 +51,9 @@ } void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, - TransitiveFilter filter, bool inclusive) { - getForwardSliceImpl(op, forwardSlice, filter); - if (!inclusive) { + ForwardSliceOptions options) { + getForwardSliceImpl(op, forwardSlice, options.filter); + if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. forwardSlice->remove(op); @@ -67,9 +67,9 @@ } void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, - TransitiveFilter filter, bool inclusive) { + ForwardSliceOptions options) { for (Operation *user : root.getUsers()) - getForwardSliceImpl(user, forwardSlice, filter); + getForwardSliceImpl(user, forwardSlice, options.filter); // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an @@ -80,22 +80,25 @@ static void getBackwardSliceImpl(Operation *op, SetVector<Operation *> *backwardSlice, - TransitiveFilter filter) { + const BackwardSliceOptions &options) { if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) return; // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive backwardSlice in the current scope.
- if (filter && !filter(op)) + if (options.filter && !options.filter(op)) return; for (const auto &en : llvm::enumerate(op->getOperands())) { auto operand = en.value(); if (auto *definingOp = operand.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) - getBackwardSliceImpl(definingOp, backwardSlice, filter); + getBackwardSliceImpl(definingOp, backwardSlice, options); } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) { + if (options.omitBlockArguments) + continue; + Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); // TODO: determine whether we want to recurse backward into the other @@ -104,7 +107,7 @@ if (parentOp && backwardSlice->count(parentOp) == 0) { assert(parentOp->getNumRegions() == 1 && parentOp->getRegion(0).getBlocks().size() == 1); - getBackwardSliceImpl(parentOp, backwardSlice, filter); + getBackwardSliceImpl(parentOp, backwardSlice, options); } } else { llvm_unreachable("No definingOp and not a block argument."); @@ -116,10 +119,10 @@ void mlir::getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice, - TransitiveFilter filter, bool inclusive) { - getBackwardSliceImpl(op, backwardSlice, filter); + BackwardSliceOptions options) { + getBackwardSliceImpl(op, backwardSlice, options); - if (!inclusive) { + if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results.
backwardSlice->remove(op); @@ -127,19 +130,18 @@ } void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, - TransitiveFilter filter, bool inclusive) { + BackwardSliceOptions options) { if (Operation *definingOp = root.getDefiningOp()) { - getBackwardSlice(definingOp, backwardSlice, filter, inclusive); + getBackwardSlice(definingOp, backwardSlice, options); return; } Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp(); - getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive); + getBackwardSlice(bbAargOwner, backwardSlice, options); } SetVector<Operation *> mlir::getSlice(Operation *op, - TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter, - bool inclusive) { + BackwardSliceOptions backwardSliceOptions, + ForwardSliceOptions forwardSliceOptions) { SetVector<Operation *> slice; slice.insert(op); @@ -150,12 +152,12 @@ auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive); + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. forwardSlice.clear(); - getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive); + getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } @@ -225,8 +227,11 @@ Operation *ancestorOp) { // Compute the backward slice of the value. SetVector<Operation *> slice; - getBackwardSlice(value, &slice, - [&](Operation *op) { return !ancestorOp->isAncestor(op); }); + BackwardSliceOptions sliceOptions; + sliceOptions.filter = [&](Operation *op) { + return !ancestorOp->isAncestor(op); + }; + getBackwardSlice(value, &slice, sliceOptions); // Check that none of the operands of the operations in the backward slice are // loop iteration arguments, and neither is the value itself.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -303,9 +303,9 @@ /// Return an unsorted slice handling scf.for region differently than /// `getSlice`. In scf.for we only want to include as part of the slice elements /// that are part of the use/def chain. -static SetVector<Operation *> getSliceContract(Operation *op, - TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter) { +static SetVector<Operation *> +getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions, + ForwardSliceOptions forwardSliceOptions) { SetVector<Operation *> slice; slice.insert(op); unsigned currentIndex = 0; @@ -315,7 +315,7 @@ auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardFilter); + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. @@ -326,11 +326,11 @@ // converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { for (Value forOpResult : forOp.getResults()) - getForwardSlice(forOpResult, &forwardSlice, forwardFilter); + getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); for (BlockArgument &arg : forOp.getRegionIterArgs()) - getForwardSlice(arg, &forwardSlice, forwardFilter); + getForwardSlice(arg, &forwardSlice, forwardSliceOptions); } else { - getForwardSlice(currentOp, &forwardSlice, forwardFilter); + getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); } slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; @@ -346,16 +346,22 @@ return llvm::any_of(op->getResultTypes(), [](Type t) { return isa<VectorType>(t); }); }; + BackwardSliceOptions backwardSliceOptions; + backwardSliceOptions.filter = hasVectorDest; + auto hasVectorSrc = [](Operation *op) { return llvm::any_of(op->getOperandTypes(), [](Type t) { return isa<VectorType>(t); }); }; + ForwardSliceOptions forwardSliceOptions; + forwardSliceOptions.filter = hasVectorSrc; + SetVector<Operation *> opToConvert; op->walk([&](vector::ContractionOp contract) { if (opToConvert.contains(contract.getOperation())) return; SetVector<Operation *> dependentOps = - getSliceContract(contract, hasVectorDest, hasVectorSrc); + getSliceContract(contract, backwardSliceOptions, forwardSliceOptions); // If any instruction cannot use MMA matrix type drop the whole // chain. MMA matrix are stored in an opaque type so they cannot be used // by all operations.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -111,20 +111,22 @@ scf::ForOp outermostEnclosingForOp, SetVector<Operation *> &backwardSlice) { DominanceInfo domInfo(outermostEnclosingForOp); - auto filter = [&](Operation *op) { + BackwardSliceOptions sliceOptions; + sliceOptions.filter = [&](Operation *op) { return domInfo.dominates(outermostEnclosingForOp, op) && !padOp->isProperAncestor(op); }; + sliceOptions.inclusive = true; + // First, add the ops required to compute the region to the backwardSlice. SetVector<Value> valuesDefinedAbove; getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(), valuesDefinedAbove); for (Value v : valuesDefinedAbove) { - getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true); + getBackwardSlice(v, &backwardSlice, sliceOptions); } // Then, add the backward slice from padOp itself. - getBackwardSlice(padOp.getOperation(), &backwardSlice, filter, - /*inclusive=*/true); + getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -797,9 +797,11 @@ // Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) { SetVector<Operation *> forwardSlice; - getForwardSlice( - outer.getInductionVar(), &forwardSlice, - [&inner](Operation *op) { return op != inner.getOperation(); }); + ForwardSliceOptions options; + options.filter = [&inner](Operation *op) { + return op != inner.getOperation(); + }; + getForwardSlice(outer.getInductionVar(), &forwardSlice, options); LogicalResult status = success(); SmallVector<Operation *, 8> toHoist; for (auto &op : outer.getBody()->without_terminator()) { diff --git a/mlir/test/IR/slice_multiple_blocks.mlir b/mlir/test/IR/slice_multiple_blocks.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/slice_multiple_blocks.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(slice-analysis-test{omit-block-arguments=true})" %s | FileCheck %s + +func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) { + %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32> + %b = memref.alloc(%arg2, %arg1) : memref<?x?xf32> + cf.br ^bb1 +^bb1() : + %c = memref.alloc(%arg0, %arg1) : memref<?x?xf32> + %d = memref.alloc(%arg0, %arg1) : memref<?x?xf32> + linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>) + outs(%c : memref<?x?xf32>) + linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>) + outs(%d : memref<?x?xf32>) + memref.dealloc %c : memref<?x?xf32> + memref.dealloc %b : memref<?x?xf32> + memref.dealloc %a : memref<?x?xf32> + memref.dealloc %d : memref<?x?xf32> + return +} +// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32> +// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32> +// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32> +// CHECK: return + +// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]:
index +// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32> +// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32> +// CHECK-DAG: %[[D:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32> +// CHECK: return diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -24,7 +24,8 @@ /// Create a function with the same signature as the parent function of `op` /// with name being the function name and a `suffix`. static LogicalResult createBackwardSliceFunction(Operation *op, - StringRef suffix) { + StringRef suffix, + bool omitBlockArguments) { func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>(); OpBuilder builder(parentFuncOp); Location loc = op->getLoc(); @@ -36,7 +37,9 @@ for (const auto &arg : enumerate(parentFuncOp.getArguments())) mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); SetVector<Operation *> slice; - getBackwardSlice(op, &slice); + BackwardSliceOptions options; + options.omitBlockArguments = omitBlockArguments; + getBackwardSlice(op, &slice, options); for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); builder.create<func::ReturnOp>(loc); @@ -53,6 +56,13 @@ StringRef getDescription() const final { return "Test Slice analysis functionality."; } + + Option<bool> omitBlockArguments{ + *this, "omit-block-arguments", + llvm::cl::desc("Test Slice analysis with multiple blocks but slice " + "omitting block arguments"), + llvm::cl::init(true)}; + void runOnOperation() override; SliceAnalysisTestPass() = default; SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} @@ -64,11 +74,11 @@ auto funcOps = module.getOps<func::FuncOp>(); unsigned opNum = 0; for (auto funcOp : funcOps) { - if (!llvm::hasSingleElement(funcOp.getBody())) { + if (!omitBlockArguments && !llvm::hasSingleElement(funcOp.getBody())) { funcOp->emitOpError("Does not support functions with multiple blocks"); signalPassFailure(); return; } // TODO: For now this is just looking for Linalg ops. It can be generalized // to look for other ops using flags.
funcOp.walk([&](Operation *op) { @@ -76,7 +86,7 @@ return WalkResult::advance(); std::string append = std::string("__backward_slice__") + std::to_string(opNum); - (void)createBackwardSliceFunction(op, append); + (void)createBackwardSliceFunction(op, append, omitBlockArguments); opNum++; return WalkResult::advance(); });