diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -63,6 +63,9 @@
 /// performs a function call analysis and bufferizes an entire module.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 
+/// This is a pass that bufferizes only SCF dialect ops.
+std::unique_ptr<Pass> createLinalgComprehensiveSCFBufferizePass();
+
 /// This is a pass that bufferizes only tensor dialect ops.
 std::unique_ptr<Pass> createLinalgComprehensiveTensorBufferizePass();
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -80,6 +80,16 @@
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
 
+def LinalgComprehensiveSCFBufferize :
+    FunctionPass<"linalg-comprehensive-scf-bufferize"> {
+  let summary = "Bufferize (tensor into memref) SCF dialect ops.";
+  let description = [{
+    This pass bufferizes only SCF dialect ops.
+    TODO: This will be moved to the bufferization dialect in the future.
+  }];
+  let constructor = "mlir::createLinalgComprehensiveSCFBufferizePass()";
+}
+
 def LinalgComprehensiveTensorBufferize :
     FunctionPass<"linalg-comprehensive-tensor-bufferize"> {
   let summary = "Bufferize (tensor into memref) tensor dialect ops.";
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -410,8 +410,14 @@
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
                                                  BufferizationState &state) {
+  // Ops may get deleted during the traversal, so do not iterate over `block`
+  // directly.
+  SmallVector<Operation *> ops;
+  ops.reserve(block->getOperations().size());
   for (Operation &op : *block)
-    if (failed(bufferize(&op, state)))
+    ops.push_back(&op);
+  for (Operation *op : ops)
+    if (failed(bufferize(op, state)))
       return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -7,10 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 namespace linalg {
@@ -220,51 +222,100 @@
     return true;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, OpBuilder & /*b*/,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-
-    for (OpResult opResult : forOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-
-      // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
-
-      OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
-      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.mapBuffer(bbArg, resultBuffer);
-      state.mapBuffer(opResult, resultBuffer);
+    Block *oldLoopBody = &forOp.getLoopBody().front();
+
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(forOp);
+
+    // Indices of all iter_args that have tensor type. These are the ones that
+    // are bufferized.
+    DenseSet<int64_t> indices;
+    for (const auto &it : llvm::enumerate(forOp.initArgs()))
+      if (it.value().getType().isa<TensorType>())
+        indices.insert(it.index());
+
+    // Given a range of values, apply `func` to those marked in `indices`.
+    // Otherwise, store the unmodified value in the result vector.
+    auto convert = [&](ValueRange values,
+                       std::function<Value(Value, int64_t)> func) {
+      SmallVector<Value> result;
+      for (const auto &it : llvm::enumerate(values)) {
+        size_t idx = it.index();
+        Value val = it.value();
+        result.push_back(indices.contains(idx) ? func(val, idx) : val);
+      }
+      return result;
+    };
+
+    // Construct a new scf.for op with memref instead of tensor values.
+    SmallVector<Value> initArgs =
+        convert(forOp.initArgs(), [&](Value val, int64_t index) {
+          return state.getResultBuffer(forOp->getOpResult(index));
+        });
+    auto newForOp =
+        rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.lowerBound(),
+                                    forOp.upperBound(), forOp.step(), initArgs);
+    Block *loopBody = &newForOp.getLoopBody().front();
+
+    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
+    // iter_args of the new loop in ToTensorOps.
+    rewriter.setInsertionPointToStart(loopBody);
+    SmallVector<Value> iterArgs =
+        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+          Value toTensorOp =
+              rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
+          state.mapBuffer(toTensorOp, val);
+          return toTensorOp;
+        });
+    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
+
+    // Erase terminator if present.
+    if (iterArgs.size() == 1)
+      rewriter.eraseOp(loopBody->getTerminator());
+
+    // Move loop body to new loop.
+    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
+
+    // Update scf.yield of new loop.
+    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> yieldValues =
+        convert(yieldOp.results(), [&](Value val, int64_t index) {
+          return rewriter.create<bufferization::ToMemrefOp>(
+              val.getLoc(), initArgs[index].getType(), val);
+        });
+    yieldOp.resultsMutable().assign(yieldValues);
+
+    // Replace loop results.
+    // TODO: Some of this functionality should be moved to BufferizationState
+    // for better code reuse.
+    for (OpOperand &use : forOp->getUses()) {
+      auto loopResult = use.get().cast<OpResult>();
+      int64_t resultIdx = loopResult.getResultNumber();
+      bool isTensorResult = indices.contains(resultIdx);
+      Value replacement = newForOp->getResult(resultIdx);
+      if (isTensorResult) {
+        rewriter.setInsertionPoint(use.getOwner());
+        replacement = rewriter.create<bufferization::ToTensorOp>(
+            loopResult.getLoc(), replacement);
+      }
+      use.set(replacement);
+      if (isTensorResult)
+        state.mapBuffer(replacement, newForOp->getResult(resultIdx));
     }
 
+    // Erase the old loop.
+    forOp.erase();
+
     // Bufferize loop body.
-    if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
+    if (failed(comprehensive_bufferize::bufferize(loopBody, state)))
       return failure();
 
-    // Finish bufferizing scf::ForOp.
-    auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-
-      // Buffers are equivalent so the work is already done and we just yield
-      // the bbArg so that it later canonicalizes away.
-      operand.set(bbArg);
-    }
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -121,6 +121,23 @@
   void runOnFunction() final;
 };
 
+struct LinalgComprehensiveSCFBufferize
+    : public LinalgComprehensiveSCFBufferizeBase<
+          LinalgComprehensiveSCFBufferize> {
+  LinalgComprehensiveSCFBufferize() {}
+
+  LinalgComprehensiveSCFBufferize(const LinalgComprehensiveSCFBufferize &p) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    scf::SCFDialect>();
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnFunction() override;
+};
+
 struct LinalgComprehensiveTensorBufferize
     : public LinalgComprehensiveTensorBufferizeBase<
           LinalgComprehensiveTensorBufferize> {
@@ -199,6 +216,15 @@
   this->runCleanupPipeline("builtin.func", getOperation());
 }
 
+void LinalgComprehensiveSCFBufferize::runOnFunction() {
+  BufferizationOptions options;
+  options.allowReturnMemref = true;
+  options.allowUnknownOps = true;
+
+  if (failed(runComprehensiveBufferize(getOperation(), options)))
+    signalPassFailure();
+}
+
 void LinalgComprehensiveTensorBufferize::runOnFunction() {
   BufferizationOptions options;
   options.addPostAnalysisStep();
@@ -226,6 +252,10 @@
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
 
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveSCFBufferizePass() {
+  return std::make_unique<LinalgComprehensiveSCFBufferize>();
+}
+
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveTensorBufferizePass() {
   return std::make_unique<LinalgComprehensiveTensorBufferize>();
 }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -9,8 +9,15 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-scf-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-vector-bufferize -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-VECTOR
 
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-scf-bufferize -canonicalize | \
+// RUN:   mlir-opt -linalg-comprehensive-tensor-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-SCF-TENSOR
+
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize | \
+// RUN:   mlir-opt -linalg-comprehensive-scf-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-SCF-TENSOR
+
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-tensor-bufferize -canonicalize | \
 // RUN:   mlir-opt -linalg-comprehensive-vector-bufferize -canonicalize | FileCheck %s --check-prefix=CHECK-TENSOR-VECTOR
 
@@ -273,5 +280,51 @@
   return %write2 : tensor<?xf32>
 }
 
+// -----
-
+// CHECK-SCF-LABEL: func @scf_for_equivalent_not_detected(
+// CHECK-SCF-SAME: %[[t1:.*]]: tensor
+
+// CHECK-SCF-TENSOR-LABEL: func @scf_for_equivalent_not_detected(
+// CHECK-SCF-TENSOR-SAME: %[[t1:.*]]: tensor
+func @scf_for_equivalent_not_detected(
+    %t1: tensor<?xf32>, %sz: index, %step: index, %f: f32) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+
+  // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK-SCF: %[[alloc:.*]] = memref.alloc
+  // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
+
+  // CHECK-SCF-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK-SCF-TENSOR: %[[alloc:.*]] = memref.alloc
+  // CHECK-SCF-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-SCF-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF-TENSOR: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
+  %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
+    // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[arg0_tensor]]
+
+    // CHECK-SCF-TENSOR: %[[alloc2:.*]] = memref.alloc
+    // CHECK-SCF-TENSOR: %[[casted2:.*]] = memref.cast %[[alloc2]]
+    // CHECK-SCF-TENSOR: memref.copy %[[arg0]], %[[casted2]]
+    // CHECK-SCF-TENSOR: memref.store %{{.*}}, %[[alloc2]]
+    %1 = tensor.insert %f into %arg0[%iv] : tensor<?xf32>
+
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+
+    // CHECK-SCF-TENSOR: scf.yield %[[casted2]]
+    scf.yield %1 : tensor<?xf32>
+  }
+  // CHECK-SCF: }
+  // CHECK-SCF-TENSOR: }
+
+  // CHECK-SCF: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]]
+  // CHECK-SCF: return %[[scf_for_tensor]]
+
+  // CHECK-SCF-TENSOR: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]]
+  // CHECK-SCF-TENSOR: return %[[scf_for_tensor]]
+  return %0 : tensor<?xf32>
+}
 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -291,6 +291,7 @@
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
 
   // The first scf.for remains but just turns into dead code.
@@ -304,7 +305,7 @@
   }
 
   // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref
-  // CHECK: return %[[ALLOC_FOR_A]] : memref
+  // CHECK: return %[[CASTED]] : memref
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -346,6 +347,7 @@
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
 
   // CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
@@ -369,7 +371,7 @@
   }
 
   // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref
-  // CHECK: return %[[ALLOC_FOR_A]] : memref
+  // CHECK: return %[[CASTED]] : memref
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6391,6 +6391,7 @@
     includes = ["include"],
    deps = [
        ":BufferizableOpInterface",
+        ":BufferizationDialect",
        ":IR",
        ":SCFDialect",
        ":Support",
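Note for reviewers: the sketch below illustrates the net effect of the new ForOpInterface::bufferize on a loop shaped like the one in the added test case. It is a simplified, hand-written example, not output copied from the pass: the function names (@sketch, @sketch_bufferized), value names, and element type are illustrative, and the memref.alloc/memref.cast/memref.copy that state.getResultBuffer may materialize for the init value (visible in the CHECK-SCF lines above) are omitted for brevity.

// Before: scf.for carries a tensor iter_arg.
func @sketch(%t: tensor<?xf32>, %lb: index, %ub: index, %step: index,
             %f: f32) -> tensor<?xf32> {
  %r = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %t) -> (tensor<?xf32>) {
    %ins = tensor.insert %f into %arg0[%iv] : tensor<?xf32>
    scf.yield %ins : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}

// After -linalg-comprehensive-scf-bufferize (schematic): only the scf.for is
// rewritten. The loop now carries a memref, the body re-enters tensor land
// through bufferization.to_tensor, and the yield goes back to a memref through
// bufferization.to_memref. Ops inside the body (tensor.insert here) are left
// for the other partial bufferization passes.
func @sketch_bufferized(%t: tensor<?xf32>, %lb: index, %ub: index, %step: index,
                        %f: f32) -> tensor<?xf32> {
  %t_m = bufferization.to_memref %t : memref<?xf32>
  %r_m = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %t_m) -> (memref<?xf32>) {
    %arg0_t = bufferization.to_tensor %arg0 : memref<?xf32>
    %ins = tensor.insert %f into %arg0_t[%iv] : tensor<?xf32>
    %ins_m = bufferization.to_memref %ins : memref<?xf32>
    scf.yield %ins_m : memref<?xf32>
  }
  %r_t = bufferization.to_tensor %r_m : memref<?xf32>
  return %r_t : tensor<?xf32>
}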