diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -19,6 +19,13 @@
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+/// Equivalence analysis for scf.for. Raise an error if iter_args are not
+/// equivalent to their corresponding loop yield values.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override;
+};
+
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
 } // namespace scf_ext
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -63,6 +63,9 @@
 /// performs a function call analysis and bufferizes an entire module.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 
+/// This is a pass that bufferizes only SCF dialect ops.
+std::unique_ptr<Pass> createLinalgComprehensiveSCFBufferizePass();
+
 /// This is a pass that bufferizes only tensor dialect ops.
 std::unique_ptr<Pass> createLinalgComprehensiveTensorBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -80,6 +80,16 @@
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
 
+def LinalgComprehensiveSCFBufferize :
+    FunctionPass<"linalg-comprehensive-scf-bufferize"> {
+  let summary = "Bufferize (tensor into memref) SCF dialect ops.";
+  let description = [{
+    This pass bufferizes only SCF dialect ops.
+    TODO: This will be moved to the bufferization dialect in the future.
+  }];
+  let constructor = "mlir::createLinalgComprehensiveSCFBufferizePass()";
+}
+
 def LinalgComprehensiveTensorBufferize :
     FunctionPass<"linalg-comprehensive-tensor-bufferize"> {
   let summary = "Bufferize (tensor into memref) tensor dialect ops.";
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -141,22 +141,7 @@
 
 /// Return `true` if a value was marked as in-place bufferized.
 bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
-  bool inplace = inplaceBufferized.contains(opResult);
-#ifndef NDEBUG
-  if (inplace) {
-    auto bufferizableOp =
-        dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
-    assert(bufferizableOp &&
-           "expected that in-place bufferized op is bufferizable");
-    SmallVector<OpOperand *> operands =
-        bufferizableOp.getAliasingOpOperand(opResult);
-    for (OpOperand *operand : operands)
-      assert(areAliasingBufferizedValues(operand->get(), opResult) &&
-             "expected that in-place bufferized OpResult aliases with "
-             "aliasing OpOperand");
-  }
-#endif // NDEBUG
-  return inplace;
+  return inplaceBufferized.contains(opResult);
 }
 
 /// Set the inPlace bufferization spec to true.
@@ -435,8 +420,14 @@
 
 LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
     Block *block, BufferizationState &state) {
+  // Ops may get deleted during the traversal, so do not iterate over `block`
+  // directly.
+  SmallVector<Operation *> ops;
+  ops.reserve(block->getOperations().size());
   for (Operation &op : *block)
-    if (failed(bufferize(&op, state)))
+    ops.push_back(&op);
+  for (Operation *op : ops)
+    if (failed(bufferize(op, state)))
       return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -7,10 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 namespace linalg {
@@ -221,40 +223,114 @@
     return true;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, OpBuilder & /*b*/,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-
-    for (OpResult opResult : forOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-
-      // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer = getResultBuffer(b, opResult, state);
-      if (!resultBuffer)
-        return failure();
-
-      OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
-      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
-      state.mapBuffer(bbArg, resultBuffer);
-      state.mapBuffer(opResult, resultBuffer);
+    Block *oldLoopBody = &forOp.getLoopBody().front();
+
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(forOp);
+
+    // Indices of all iter_args that have tensor type. These are the ones that
+    // are bufferized.
+    DenseSet<int64_t> indices;
+    for (const auto &it : llvm::enumerate(forOp.initArgs()))
+      if (it.value().getType().isa<TensorType>())
+        indices.insert(it.index());
+
+    // Given a range of values, apply `func` to those marked in `indices`.
+    // Otherwise, store the unmodified value in the result vector.
+    auto convert = [&](ValueRange values,
+                       std::function<Value(Value, int64_t)> func) {
+      SmallVector<Value> result;
+      for (const auto &it : llvm::enumerate(values)) {
+        size_t idx = it.index();
+        Value val = it.value();
+        result.push_back(indices.contains(idx) ? func(val, idx) : val);
+      }
+      return result;
+    };
+
+    // Construct a new scf.for op with memref instead of tensor values.
+    SmallVector<Value> initArgs =
+        convert(forOp.initArgs(), [&](Value val, int64_t index) {
+          return getResultBuffer(rewriter, forOp->getOpResult(index), state);
+          // return state.lookupBuffer(val);
+        });
+    auto newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+        initArgs);
+    Block *loopBody = &newForOp.getLoopBody().front();
+
+    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
+    // iter_args of the new loop in ToTensorOps.
+    rewriter.setInsertionPointToStart(loopBody);
+    SmallVector<Value> iterArgs =
+        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+          Value toTensorOp =
+              rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
+          state.mapBuffer(toTensorOp, val);
+          return toTensorOp;
+        });
+    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
+
+    // Erase terminator if present.
+    if (iterArgs.size() == 1)
+      rewriter.eraseOp(loopBody->getTerminator());
+
+    // Move loop body to new loop.
+    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
+
+    // Update scf.yield of new loop.
+    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> yieldValues =
+        convert(yieldOp.results(), [&](Value val, int64_t index) {
+          return rewriter.create<bufferization::ToMemrefOp>(
+              val.getLoc(), initArgs[index].getType(), val);
+        });
+    yieldOp.resultsMutable().assign(yieldValues);
+
+    // Replace loop results.
+    // TODO: Some of this functionality should be moved to BufferizationState
+    // for better code reuse.
+    for (OpOperand &use : forOp->getUses()) {
+      auto loopResult = use.get().cast<OpResult>();
+      int64_t resultIdx = loopResult.getResultNumber();
+      bool isTensorResult = indices.contains(resultIdx);
+      Value replacement = newForOp->getResult(resultIdx);
+      if (isTensorResult) {
+        rewriter.setInsertionPoint(use.getOwner());
+        replacement = rewriter.create<bufferization::ToTensorOp>(
+            loopResult.getLoc(), replacement);
+      }
+      use.set(replacement);
+      if (isTensorResult)
+        state.mapBuffer(replacement, newForOp->getResult(resultIdx));
     }
 
+    // Erase the old loop.
+    forOp.erase();
+
     // Bufferize loop body.
-    if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
+    if (failed(comprehensive_bufferize::bufferize(loopBody, state)))
       return failure();
 
-    // Finish bufferizing scf::ForOp.
-    auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
+    return success();
+  }
+};
+
+LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
+    AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
+                                       SmallVector<Operation *> &newOps) {
+  LogicalResult status = success();
+  op->walk([&](scf::YieldOp yieldOp) {
+    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+    if (!forOp)
+      return WalkResult::advance();
+
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
       if (!tensorType)
@@ -267,19 +343,19 @@
                                                    bbArg)) {
         // TODO: this could get resolved with copies but it can also turn into
         // swaps so we need to be careful about order of copies.
-        return yieldOp->emitError()
-               << "Yield operand #" << operand.getOperandNumber()
-               << " does not bufferize to an equivalent buffer to the matching"
-               << " enclosing scf::for operand";
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to an equivalent buffer to the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt();
       }
-
-      // Buffers are equivalent so the work is already done and we just yield
-      // the bbArg so that it later canonicalizes away.
-      operand.set(bbArg);
     }
-    return success();
-  }
-};
+
+    return WalkResult::advance();
+  });
+  return status;
+}
 
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
   // TODO: Find a way to enable these steps automatically.
   options.addPostAnalysisStep();
+  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
   options.allowReturnMemref = this->allowReturnMemref;
   options.allowUnknownOps = this->allowUnknownOps;
@@ -119,6 +120,23 @@
   void runOnFunction() final;
 };
 
+struct LinalgComprehensiveSCFBufferize
+    : public LinalgComprehensiveSCFBufferizeBase<
+          LinalgComprehensiveSCFBufferize> {
+  LinalgComprehensiveSCFBufferize() {}
+
+  LinalgComprehensiveSCFBufferize(const LinalgComprehensiveSCFBufferize &p) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnFunction() override;
+};
+
 struct LinalgComprehensiveTensorBufferize
     : public LinalgComprehensiveTensorBufferizeBase<
           LinalgComprehensiveTensorBufferize> {
@@ -197,6 +215,15 @@
   this->runCleanupPipeline("builtin.func", getOperation());
 }
 
+void LinalgComprehensiveSCFBufferize::runOnFunction() {
+  BufferizationOptions options;
+  options.allowReturnMemref = true;
+  options.allowUnknownOps = true;
+
+  if (failed(runComprehensiveBufferize(getOperation(), options)))
+    signalPassFailure();
+}
+
 void LinalgComprehensiveTensorBufferize::runOnFunction() {
   BufferizationOptions options;
   options.addPostAnalysisStep();
@@ -224,6 +251,10 @@
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
 
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveSCFBufferizePass() {
+  return std::make_unique<LinalgComprehensiveSCFBufferize>();
+}
+
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveTensorBufferizePass() {
   return std::make_unique<LinalgComprehensiveTensorBufferize>();
 }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -9,8 +9,15 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-scf-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-vector-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-VECTOR
 
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-scf-bufferize -canonicalize | \
+// RUN:   mlir-opt -linalg-comprehensive-tensor-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-SCF-TENSOR
+
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize | \
+// RUN:   mlir-opt -linalg-comprehensive-scf-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-SCF-TENSOR
+
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize | \
 // RUN:   mlir-opt -linalg-comprehensive-vector-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-TENSOR-VECTOR
 
@@ -267,5 +274,51 @@
   return %write2 : tensor
 }
 
+// -----
+
+// CHECK-SCF-LABEL: func @scf_for_equivalent_not_detected(
+//  CHECK-SCF-SAME:   %[[t1:.*]]: tensor
+
+// CHECK-SCF-TENSOR-LABEL: func @scf_for_equivalent_not_detected(
+//  CHECK-SCF-TENSOR-SAME:   %[[t1:.*]]: tensor
+func @scf_for_equivalent_not_detected(
+    %t1: tensor, %sz: index, %step: index, %f: f32) -> tensor {
+  %c0 = arith.constant 0 : index
+
+  // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK-SCF: %[[alloc:.*]] = memref.alloc
+  // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
+
+  // CHECK-SCF-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK-SCF-TENSOR: %[[alloc:.*]] = memref.alloc
+  // CHECK-SCF-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-SCF-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF-TENSOR: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
+  %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor {
+    // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[arg0_tensor]]
+
+    // CHECK-SCF-TENSOR: %[[alloc2:.*]] = memref.alloc
+    // CHECK-SCF-TENSOR: %[[casted2:.*]] = memref.cast %[[alloc2]]
+    // CHECK-SCF-TENSOR: memref.copy %[[arg0]], %[[casted2]]
+    // CHECK-SCF-TENSOR: memref.store %{{.*}}, %[[alloc2]]
+    %1 = tensor.insert %f into %arg0[%iv] : tensor
+
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+
+    // CHECK-SCF-TENSOR: scf.yield %[[casted2]]
+    scf.yield %1 : tensor
+  }
+  // CHECK-SCF: }
+  // CHECK-SCF-TENSOR: }
+
+  // CHECK-SCF: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]]
+  // CHECK-SCF: return %[[scf_for_tensor]]
+
+  // CHECK-SCF-TENSOR: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]]
+  // CHECK-SCF-TENSOR: return %[[scf_for_tensor]]
+  return %0 : tensor
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -291,6 +291,7 @@
     -> (tensor, tensor)
 {
   // CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK:   linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
 
   // The first scf.for remains but just turns into dead code.
@@ -304,7 +305,7 @@
   }
 
   // CHECK:   memref.dealloc %[[ALLOC_FOR_A]] : memref
-  // CHECK:   return %[[ALLOC_FOR_A]] : memref
+  // CHECK:   return %[[CASTED]] : memref
   return %r0, %r1: tensor, tensor
 }
 
@@ -346,6 +347,7 @@
     -> (tensor, tensor)
 {
   // CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK:   linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
 
   // CHECK:   %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
@@ -369,7 +371,7 @@
   }
 
   // CHECK:   memref.dealloc %[[ALLOC_FOR_A]] : memref
-  // CHECK:   return %[[ALLOC_FOR_A]] : memref
+  // CHECK:   return %[[CASTED]] : memref
   return %r0#0, %r0#1: tensor, tensor
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6391,6 +6391,7 @@
     includes = ["include"],
     deps = [
         ":BufferizableOpInterface",
+        ":BufferizationDialect",
        ":IR",
        ":SCFDialect",
        ":Support",
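
For reference, here is a rough before/after sketch of the transformation that the new -linalg-comprehensive-scf-bufferize pass is intended to perform on a loop like the one in the new test case: only scf.for/scf.yield are bufferized, and the loop body, which still operates on tensors, is bridged with bufferization.to_tensor / bufferization.to_memref ops. This is an illustration only, not output taken from the patch; the element types are made up, and the buffer allocation/copy that the pass inserts in front of the loop (visible in the CHECK-SCF lines above) is omitted.

  // Input: a loop carrying a tensor iter_arg.
  func @example(%t: tensor<?xf32>, %sz: index, %step: index, %f: f32) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %r = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t) -> (tensor<?xf32>) {
      %ins = tensor.insert %f into %arg0[%iv] : tensor<?xf32>
      scf.yield %ins : tensor<?xf32>
    }
    return %r : tensor<?xf32>
  }

  // After -linalg-comprehensive-scf-bufferize (simplified): the loop now
  // carries a memref iter_arg; the body converts it back to a tensor, and the
  // yielded tensor is converted back to a memref.
  func @example(%t: tensor<?xf32>, %sz: index, %step: index, %f: f32) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %t_m = bufferization.to_memref %t : memref<?xf32>
    %r_m = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t_m) -> (memref<?xf32>) {
      %arg0_t = bufferization.to_tensor %arg0 : memref<?xf32>
      %ins = tensor.insert %f into %arg0_t[%iv] : tensor<?xf32>
      %ins_m = bufferization.to_memref %ins : memref<?xf32>
      scf.yield %ins_m : memref<?xf32>
    }
    %r_t = bufferization.to_tensor %r_m : memref<?xf32>
    return %r_t : tensor<?xf32>
  }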