diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -46,6 +46,13 @@
 /// with differing element types or memory spaces.
 FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
                                           MemRefType type);
+
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                       ToMemrefOp toMemref,
+                                       bool allowSameType = true);
+
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -21,10 +21,6 @@
                                               MemRefType destType) {
   auto srcType = value.getType().cast<MemRefType>();
 
-  // Casting to the same type, nothing to do.
-  if (srcType == destType)
-    return value;
-
   // Element type, rank and memory space must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
@@ -79,6 +75,55 @@
   return copy;
 }
 
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
+  if (!memrefToTensor)
+    return failure();
+
+  Type srcType = memrefToTensor.memref().getType();
+  Type destType = toMemref.getType();
+
+  // Directly rewrite if the type did not change.
+  if (srcType == destType) {
+    // Function can be configured to only handle cases where a cast is needed.
+    if (!allowSameType)
+      return failure();
+    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+    return success();
+  }
+
+  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+  auto rankedDestType = destType.dyn_cast<MemRefType>();
+  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+  // Ranked memref -> Ranked memref cast.
+  if (rankedSrcType && rankedDestType) {
+    FailureOr<Value> replacement = castOrReallocMemRefValue(
+        rewriter, memrefToTensor.memref(), rankedDestType);
+    if (failed(replacement))
+      return failure();
+
+    rewriter.replaceOp(toMemref, *replacement);
+    return success();
+  }
+
+  // Unranked memref -> Ranked memref cast: May require a copy.
+  // TODO: Not implemented at the moment.
+  if (unrankedSrcType && rankedDestType)
+    return failure();
+
+  // Unranked memref -> unranked memref cast
+  // Ranked memref -> unranked memref cast: No copy needed.
+  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+         "expected that types are cast compatible");
+  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+                                              memrefToTensor.memref());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -249,51 +294,6 @@
   }
 };
 
-/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
-/// to_memref op are different, a memref.cast is needed.
-static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                              ToMemrefOp toMemref,
-                                              bool allowSameType = true) {
-  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
-  if (!memrefToTensor)
-    return failure();
-
-  Type srcType = memrefToTensor.memref().getType();
-  Type destType = toMemref.getType();
-
-  // Function can be configured to only handle cases where a cast is needed.
-  if (!allowSameType && srcType == destType)
-    return failure();
-
-  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
-  auto rankedDestType = destType.dyn_cast<MemRefType>();
-  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
-
-  // Ranked memref -> Ranked memref cast.
-  if (rankedSrcType && rankedDestType) {
-    FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, memrefToTensor.memref(), rankedDestType);
-    if (failed(replacement))
-      return failure();
-
-    rewriter.replaceOp(toMemref, *replacement);
-    return success();
-  }
-
-  // Unranked memref -> Ranked memref cast: May require a copy.
-  // TODO: Not implemented at the moment.
-  if (unrankedSrcType && rankedDestType)
-    return failure();
-
-  // Unranked memref -> unranked memref cast
-  // Ranked memref -> unranked memref cast: No copy needed.
-  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
-         "expected that types are cast compatible");
-  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
-                                              memrefToTensor.memref());
-  return success();
-}
-
 /// Canonicalize bufferization.to_tensor + bufferization.to_memref to
 /// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
 struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -242,65 +242,6 @@
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
-    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
-  BufferizationPattern(MLIRContext *context, BufferizationState &state,
-                       PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
-        state(&state) {}
-
-  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
-                                PatternRewriter &rewriter) const override {
-    const BufferizationOptions &options = state->getOptions();
-
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(bufferizableOp.getOperation()))
-      return failure();
-    if (!options.isOpAllowed(bufferizableOp.getOperation()))
-      return failure();
-    return bufferizableOp.bufferize(rewriter, *state);
-  }
-
-private:
-  BufferizationState *const state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
-  if (!options.allowUnknownOps) {
-    // Check if all ops were bufferized.
-    LogicalResult status = success();
-    op->walk([&](Operation *op) {
-      if (!hasTensorSemantics(op))
-        return WalkResult::advance();
-
-      // Bufferization dialect ops will canonicalize away if all other ops are
-      // bufferized.
-      if (isa<ToTensorOp, ToMemrefOp>(op))
-        return WalkResult::advance();
-
-      // Ops that are not in the allow list can be ignored.
-      if (!options.isOpAllowed(op))
-        return WalkResult::advance();
-
-      // Ops without any uses and no side effects will fold away.
-      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
-        return WalkResult::advance();
-
-      status = op->emitError("op was not bufferized");
-      return WalkResult::interrupt();
-    });
-
-    if (failed(status))
-      return status;
-  }
-
-  return success();
-}
-
 LogicalResult
 bufferization::finalizeBuffers(Operation *op,
                                const BufferizationOptions &options) {
@@ -335,35 +276,131 @@
   return success();
 }
 
+namespace {
+/// A rewriter that keeps track of extra information during bufferization.
+class BufferizationRewriter : public IRRewriter {
+public:
+  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
+                        DenseSet<Operation *> &toMemrefOps,
+                        SmallVector<Operation *> &worklist)
+      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
+        worklist(worklist) {}
+
+protected:
+  void notifyOperationRemoved(Operation *op) override {
+    IRRewriter::notifyOperationRemoved(op);
+    erasedOps.insert(op);
+  }
+
+  void notifyOperationInserted(Operation *op) override {
+    IRRewriter::notifyOperationInserted(op);
+
+    // Keep track of to_memref ops.
+    if (isa<ToMemrefOp>(op)) {
+      toMemrefOps.insert(op);
+      return;
+    }
+
+    // Skip to_tensor ops.
+    if (isa<ToTensorOp>(op))
+      return;
+
+    // A new bufferizable op was inserted. Add it to the worklist.
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  }
+
+private:
+  /// A set of all erased ops.
+  DenseSet<Operation *> &erasedOps;
+
+  /// A set of all to_memref ops.
+  DenseSet<Operation *> &toMemrefOps;
+
+  /// The list of bufferizable ops.
+  SmallVector<Operation *> &worklist;
+};
+} // namespace
+
 LogicalResult
 bufferization::bufferizeOp(Operation *op,
                            BufferizationState &bufferizationState) {
-  // Bufferize the op and its nested ops.
-  RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
-
-  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
-  // know the exact memref type of all operands. Otherwise, we have to use a
-  // memref type with a fully dynamic layout map, which has to canonicalize
-  // away. This is less efficient.
+  const auto &options = bufferizationState.getOptions();
+
+  // Keep track of to_memref ops.
+  DenseSet<Operation *> toMemrefOps;
+  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
+
+  // Gather all bufferizable ops in top-to-bottom order.
   //
-  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
-  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
-  // compatible layout maps when doing a traversal other than top-to-bottom.
-  // There are currently no canonicalization patterns to fold these away.
-  GreedyRewriteConfig config;
-  config.useTopDownTraversal = true;
-
-  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
-  // would be more efficient because every bufferization pattern is guaranteed
-  // to apply only a single time (otherwise, an assertion would be triggered).
-  // However, there are restrictions wrt. erasing ops during a preorder walk,
-  // which would likely require a larger refactoring.
-  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-    return failure();
+  // We should ideally know the exact memref type of all operands when
+  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
+  // Otherwise, we have to use a memref type with a fully dynamic layout map,
+  // which has to canonicalize away. This is less efficient.
+ // + // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies + // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible + // layout maps when doing a traversal other than top-to-bottom. These would + // not easily fold away. + SmallVector worklist; + op->walk([&](Operation *op) { + if (hasTensorSemantics(op)) + worklist.push_back(op); + }); - if (failed(checkBufferizationResult(op, bufferizationState.getOptions()))) - return failure(); + // Keep track of all erased ops. + DenseSet erasedOps; + + // Bufferize all ops. + BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, + worklist); + for (unsigned i = 0; i < worklist.size(); ++i) { + Operation *op = worklist[i]; + // Skip ops that were erased. + if (erasedOps.contains(op)) + continue; + // Skip ops that are not bufferizable. + auto bufferizableOp = dyn_cast(op); + if (!bufferizableOp) + continue; + // Continue ops that are not allowed. + if (!options.isOpAllowed(op)) + continue; + // Bufferize the op. + rewriter.setInsertionPoint(op); + (void)bufferizableOp.bufferize(rewriter, bufferizationState); + } + + // Fold all to_memref(to_tensor(x)) pairs. + for (Operation *op : toMemrefOps) { + if (erasedOps.contains(op)) + continue; + rewriter.setInsertionPoint(op); + (void)bufferization::foldToMemrefToTensorPair(rewriter, + cast(op)); + } + + /// Check the result of bufferization. Return an error if an op was not + /// bufferized, unless partial bufferization is allowed. + if (bufferizationState.getOptions().allowUnknownOps) + return success(); + + for (Operation *op : worklist) { + // Skip ops that are entirely gone. + if (erasedOps.contains(op)) + continue; + // Ops that no longer have tensor semantics (because they were updated + // in-place) are allowed. + if (!hasTensorSemantics(op)) + continue; + // Continue ops that are not allowed. + if (!options.isOpAllowed(op)) + continue; + // Ops without any uses and no side effects will fold away. + if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) + continue; + return op->emitError("op was not bufferized"); + } return success(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -884,8 +884,8 @@ // CHECK: %[[cloned:.*]] = bufferization.clone %[[t]] // CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[cloned]]) // This alloc is for the linalg.init_tensor. -// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}}) -// CHECK: memref.dealloc %[[iter]] +// CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}}) +// CHECK-DAG: memref.dealloc %[[iter]] // This alloc is for the scf.yield. 
 // CHECK: %[[alloc3:.*]] = memref.alloc(%{{.*}})
 // CHECK: memref.copy %[[alloc2]], %[[alloc3]]
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
@@ -72,14 +72,6 @@
   return %0 : f32
 }
 
-// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
-// CHECK: return %[[RET]] : tensor<0xindex>
-func.func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-  %0 = tensor.from_elements : tensor<0xindex>
-  return %0 : tensor<0xindex>
-}
-
 // CHECK-LABEL: func @tensor.from_elements_0d(
 // CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
 // CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
@@ -185,8 +177,8 @@
 // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
+// CHECK-DAG: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
+// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
 // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
 // CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
 // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@@ -212,7 +204,7 @@
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
+// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
 // CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
 // CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
 // CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@@ -278,8 +270,8 @@
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
   // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
-  // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
-  // CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
   // CHECK: memref.copy %[[m1]], %[[alloc]]
   // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]