diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,7 +1342,13 @@
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
-  let extraClassDeclaration = commonExtraClassDeclaration;
+
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    static FailureOr<MemRefType> computeExpandedType(
+        MemRefType srcType, ArrayRef<int64_t> resultShape,
+        ArrayRef<ReassociationIndices> reassociation);
+  }];
+
   let hasVerifier = 1;
 }
 
@@ -1389,6 +1395,7 @@
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1422,12 +1429,16 @@
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
+
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     /// Return `true` if this source MemRef type is guaranteed to be collapsible
    /// according to the given reassociation indices. In the presence of dynamic
    /// strides this is usually not the case.
    static bool isGuaranteedCollapsible(
        MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+
+    static MemRefType computeCollapsedType(
+        MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1801,9 +1801,9 @@
   return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
-static FailureOr<MemRefType>
-computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
-                    ArrayRef<ReassociationIndices> reassociation) {
+FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
+    MemRefType srcType, ArrayRef<int64_t> resultShape,
+    ArrayRef<ReassociationIndices> reassociation) {
   if (srcType.getLayout().isIdentity()) {
     // If the source is contiguous (i.e., no layout map specified), so is the
     // result.
@@ -1827,7 +1827,7 @@
   // Only ranked memref source values are supported.
   auto srcType = src.getType().cast<MemRefType>();
   FailureOr<MemRefType> resultType =
-      computeExpandedType(srcType, resultShape, reassociation);
+      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
   // Failure of this assertion usually indicates a problem with the source
   // type, e.g., could not get strides/offset.
   assert(succeeded(resultType) && "could not compute layout");
@@ -1846,7 +1846,7 @@
     return failure();
 
   // Compute expected result type (including layout map).
-  FailureOr<MemRefType> expectedResultType = computeExpandedType(
+  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
       srcType, resultType.getShape(), getReassociationIndices());
   if (failed(expectedResultType))
     return emitOpError("invalid source layout map");
@@ -1943,9 +1943,8 @@
                                         /*strict=*/true));
 }
 
-static MemRefType
-computeCollapsedType(MemRefType srcType,
-                     ArrayRef<ReassociationIndices> reassociation) {
+MemRefType CollapseShapeOp::computeCollapsedType(
+    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
@@ -1979,7 +1978,8 @@
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
   auto srcType = src.getType().cast<MemRefType>();
-  MemRefType resultType = computeCollapsedType(srcType, reassociation);
+  MemRefType resultType =
+      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
   build(b, result, resultType, src, attrs);
   result.addAttribute(::mlir::getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -2039,9 +2039,9 @@
     if (!CastOp::canFoldIntoConsumerOp(cast))
       return failure();
 
-    Type newResultType =
-        computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
-                             op.getReassociationIndices());
+    Type newResultType = CollapseShapeOp::computeCollapsedType(
+        cast.getOperand().getType().cast<MemRefType>(),
+        op.getReassociationIndices());
 
     if (newResultType == op.getResultType()) {
       rewriter.updateRootInPlace(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,6 +109,29 @@
     return BufferRelation::Equivalent;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        collapseShapeOp.getSrc(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
+        srcBufferType, collapseShapeOp.getReassociationIndices());
+
+    if (!canBeCollapsed) {
+      // If dims cannot be collapsed, this op bufferizes to a new allocation.
+      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
+      return bufferization::getMemRefTypeWithStaticIdentityLayout(
+          tensorResultType, srcBufferType.getMemorySpaceAsInt());
+    }
+
+    return memref::CollapseShapeOp::computeCollapsedType(
+        srcBufferType, collapseShapeOp.getReassociationIndices());
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
@@ -232,6 +255,23 @@
     return BufferRelation::Equivalent;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        expandShapeOp.getSrc(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
+        srcBufferType, expandShapeOp.getResultType().getShape(),
+        expandShapeOp.getReassociationIndices());
+    if (failed(maybeResultType))
+      return failure();
+    return *maybeResultType;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -858,3 +858,21 @@
   }
   return %r1#1 : tensor<5xi1>
 }
+
+// -----
+
+// This is a regression test. Just check that the IR bufferizes.
+
+// CHECK-LABEL: func @buffer_type_of_collapse_shape
+func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
+  %true = arith.constant true
+  %0 = scf.while (%arg1 = %arg0) : (tensor<f64>) -> (tensor<f64>) {
+    scf.condition(%true) %arg1 : tensor<f64>
+  } do {
+  ^bb0(%_: tensor<f64>):
+    %3 = bufferization.alloc_tensor() : tensor<1xf64>
+    %16 = tensor.collapse_shape %3 [] : tensor<1xf64> into tensor<f64>
+    scf.yield %16 : tensor<f64>
+  }
+  return
+}