diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -427,6 +427,8 @@
 /// BufferizationState provides helper functions for performing bufferization
 /// rewrites and handling memref buffers.
 struct BufferizationState {
+  enum ForceInPlacability { FORCE_INPLACE, FORCE_OUT_OF_PLACE };
+
   BufferizationState(const AnalysisState &analysisState)
       : analysisState(analysisState) {}
 
@@ -448,11 +450,19 @@
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
+  ///
+  /// Whether a buffer is in-place or out-of-place is queried from the analysis
+  /// state. Some analyses may always conservatively opt for out-of-place
+  /// bufferization. Inplacability decisions can be overridden with the optional
+  /// `overrideInPlace` parameter.
   FailureOr<Value>
   getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-            bool forceInPlace = false,
+            Optional<ForceInPlacability> overrideInPlace = None,
             Optional<Operation *> customCopyInsertionPoint = None);
 
+  /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+  BaseMemRefType getBufferType(OpOperand &opOperand) const;
+
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const {
     return analysisState.getOptions();
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1295,7 +1295,14 @@
     "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    /// Return `true` if this source MemRef type is guaranteed to be collapsible
+    /// according to the given reassociation indices. In the presence of dynamic
+    /// strides this is usually not the case.
+    static bool isGuaranteedCollapsible(
+        MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+  }];
+
+  let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -247,12 +247,12 @@
                      tensor);
 }
 
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
+/// bufferization was decided.
 FailureOr<Value>
 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-                              bool forceInPlace,
+                              Optional<ForceInPlacability> overrideInPlace,
                               Optional<Operation *> customCopyInsertionPoint) {
   const BufferizationOptions &options = analysisState.getOptions();
   OpBuilder::InsertionGuard guard(rewriter);
@@ -263,7 +263,11 @@
   Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand, options);
 
-  if (forceInPlace || analysisState.isInPlace(opOperand))
+  // Can `operandBuffer` be used directly or do we need a copy?
+  bool inplace =
+      overrideInPlace != FORCE_OUT_OF_PLACE &&
+      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
+  if (inplace)
     return operandBuffer;
 
   // Bufferizing out-of-place: Allocate a new buffer.
@@ -317,6 +321,18 @@
   return resultBuffer;
 }
 
+/// Return the buffer type for a given OpOperand (tensor) after bufferization.
+BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
+  Value tensor = opOperand.get();
+  auto tensorType = tensor.getType().dyn_cast<TensorType>();
+  assert(tensorType && "unexpected non-tensor type");
+
+  if (auto toTensorOp = tensor.getDefiningOp<ToTensorOp>())
+    return toTensorOp.memref().getType().cast<BaseMemRefType>();
+
+  return getMemRefType(tensorType, getOptions());
+}
+
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                   Operation *op,
                                                   ValueRange values) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,8 +48,9 @@
       continue;
     }
     // Input operands are never written to.
-    newInputBuffers.push_back(
-        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
+    newInputBuffers.push_back(*state.getBuffer(
+        rewriter, *opOperand,
+        BufferizationState::ForceInPlacability::FORCE_INPLACE));
   }
 
   // New output operands for the cloned op.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1812,10 +1812,12 @@
 ///
 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
 /// not possible to check this by inspecting a MemRefType in the general case.
-/// But it is assumed. If this is not the case, the behavior is undefined.
+/// If non-contiguity cannot be checked statically, the collapse is assumed to
+/// be valid (and thus accepted by this function) unless `strict = true`.
 static FailureOr<AffineMap>
 computeCollapsedLayoutMap(MemRefType srcType,
-                          ArrayRef<ReassociationIndices> reassociation) {
+                          ArrayRef<ReassociationIndices> reassociation,
+                          bool strict = false) {
   int64_t srcOffset;
   SmallVector<int64_t> srcStrides;
   auto srcShape = srcType.getShape();
@@ -1837,11 +1839,26 @@
     auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
       stride = stride * Wrapper::size(srcShape[idx]);
-      // Both are either static strides of the same value, or both are dynamic.
-      // The dynamic case is best effort atm : we can't check it statically.
-      // One exception to the dynamic check is when the srcShape is `1`, in
-      // which case it can never produce a non-contiguity.
-      if (stride != Wrapper::stride(srcStrides[idx - 1]) && srcShape[idx] != 1)
+
+      // Both source and result stride must have the same static value. In that
+      // case, we can be sure, that the dimensions are collapsible (because they
+      // are contiguous).
+      //
+      // One special case is when the srcShape is `1`, in which case it can
+      // never produce non-contiguity.
+      if (srcShape[idx] == 1)
+        continue;
+
+      // If `strict = false` (default during op verification), we accept cases
+      // where one or both strides are dynamic. This is best effort: We reject
+      // ops where obviously non-contiguous dims are collapsed, but accept ops
+      // where we cannot be sure statically. Such ops may fail at runtime. See
+      // the op documentation for details.
+      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      if (strict && (stride.saturated || srcStride.saturated))
+        return failure();
+
+      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
   }
@@ -1849,6 +1866,16 @@
                                      srcType.getContext());
 }
 
+bool ExpandShapeOp::isGuaranteedCollapsible(
+    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
+  // MemRefs with standard layout are always collapsible.
+  if (srcType.getLayout().isIdentity())
+    return true;
+
+  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
+                                             /*strict=*/true));
+}
+
 static MemRefType
 computeCollapsedType(MemRefType srcType,
                      ArrayRef<ReassociationIndices> reassociation) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,12 +109,12 @@
                           BufferizationState &state) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-    Value buffer =
-        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
+    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
 
     if (tensorResultType.getRank() == 0) {
       // 0-d collapses must go through a different op builder.
-      auto bufferType = buffer.getType().cast<MemRefType>();
+      Value buffer = *state.getBuffer(rewriter, srcOperand);
       MemRefType resultType;
 
       if (bufferType.getLayout().isIdentity()) {
@@ -141,6 +141,18 @@
       return success();
     }
 
+    // If the dims are not collapsible (due to an incompatible source layout
+    // map), force an out-of-place bufferization, i.e., a buffer copy. This
+    // newly allocated buffer will have no layout map and thus be collapsible.
+    bool canBeCollapsed = memref::ExpandShapeOp::isGuaranteedCollapsible(
+        bufferType, collapseShapeOp.getReassociationIndices());
+    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
+        canBeCollapsed
+            ? None
+            : Optional<BufferizationState::ForceInPlacability>(
+                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
+    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
+
     // Result type is inferred by the builder.
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
@@ -248,9 +260,12 @@
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
+
+    // Even if this op was decided to bufferize out-of-place, do not insert the
+    // buffer copy yet. This is done later in this function.
     Value srcMemref =
         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
-                         /*forceInPlace=*/true);
+                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -384,3 +384,20 @@
   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   return %1 : tensor<i32>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
+func @tensor.collapse_shape_of_slice2(
+    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
+    -> tensor<87x63648xi64> {
+  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
+  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>
+
+  // This memref may not be collapsible, so the buffer must be copied to get rid
+  // of the layout map.
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
+  // CHECK: memref.copy %[[subview]], %[[alloc]]
+  // CHECK: memref.collapse_shape %[[alloc]] [
+  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
+  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
+  return %1 : tensor<87x63648xi64>
+}