diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -24,33 +24,49 @@
 //===----------------------------------------------------------------------===//

 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
-    [BufferizableOpInterface,
+    [AttrSizedOperandSegments, BufferizableOpInterface,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "buffer allocation in tensor land";

   let description = [{
     `bufferization.alloc_tensor` is an operation that bufferizes to a buffer
-    allocation of a given shape. The shape could be dynamic or static.
-    Reading from the result of an `alloc_tensor` op yields an undefined value.
+    allocation of a given shape. The shape could be dynamic or static. The
+    optional `copy` operand specifies the contents of the tensor. If no `copy`
+    operand is specified, reading from the result of an `alloc_tensor` op
+    yields an undefined value.
+
+    If `copy` is specified, no dynamic sizes should be passed, since they are
+    the same as the dynamic sizes of the `copy` operand.
+
+    The optional `escape` attribute indicates whether the buffer escapes the
+    parent block or not. In the latter case, the buffer is deallocated at the
+    end of the block (during bufferization). In the former case, the buffer
+    is not deallocated and leaks.

     `alloc_tensor` is a helper op for bufferization. It marks the beginning of
     a new tensor SSA use-def chain and is used to control in-place
     bufferization decisions during One-Shot Bufferize.
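+
+    Example (syntax as exercised by the test cases added in this patch):
+
+    ```mlir
+    // No `copy` operand: contents are undefined.
+    %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+    // Initialized with the contents of %t; the buffer escapes the block.
+    %1 = bufferization.alloc_tensor() copy(%t) {escape = true} : tensor<5xf32>
+    ```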
   }];

-  let arguments = (ins Variadic<Index>:$dynamicSizes);
+  let arguments = (ins Variadic<Index>:$dynamicSizes,
+                       Optional<AnyTensor>:$copy,
+                       OptionalAttr<BoolAttr>:$escape);

   let results = (outs AnyTensor:$result);

-  let assemblyFormat = "`(`$dynamicSizes`)` attr-dict `:` type($result)";
-
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);

-    bool isMemoryWrite(OpResult opResult, const AnalysisState &state) const {
-      // AllocTensorOps allocate but do not write.
-      return false;
-    }
+    bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
+
+    bool bufferizesToMemoryRead(OpOperand &opOperand,
+                                const AnalysisState &state);
+
+    bool bufferizesToMemoryWrite(OpOperand &opOperand,
+                                 const AnalysisState &state);
+
+    SmallVector<OpResult> getAliasingOpResult(
+        OpOperand &opOperand, const AnalysisState &state);

     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
@@ -65,6 +81,7 @@
     // the tensor at dimension `idx`. Asserts that the shape is
     // dynamic at that `idx`.
     unsigned getIndexOfDynamicSize(unsigned idx) {
+      assert(!copy() && "no dim sizes specified when copying a tensor");
       assert(isDynamicDim(idx) && "expected dynamic size");
       ArrayRef<int64_t> shape = getType().getShape();
       return std::count_if(
@@ -74,9 +91,7 @@

     // Return the Value of the dynamic size of the tensor at dimension
     // `idx`. Asserts that the shape is dynamic at that `idx.
-    Value getDynamicSize(unsigned idx) {
-      return getOperand(getIndexOfDynamicSize(idx));
-    }
+    Value getDynamicSize(OpBuilder &b, unsigned idx);

     // Assert that the size of the result tensor is static at `idx`
     // and return the shape.
@@ -87,6 +102,7 @@
   }];

   let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -139,16 +139,58 @@
   if (getOperation()->getUses().empty())
     return success();

-  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
+  Optional<bool> dealloc = llvm::None;
+  if (escape().hasValue())
+    dealloc = !*escape();
+  FailureOr<Value> alloc =
+      state.createAlloc(rewriter, getLoc(), getResult(), dealloc);
   if (failed(alloc))
     return failure();
+  if (copy()) {
+    FailureOr<Value> copyValueBuffer = state.getBuffer(
+        rewriter, getOperation()->getOpOperand(getNumOperands() - 1));
+    if (failed(copyValueBuffer))
+      return failure();
+    if (failed(state.getOptions().createMemCpy(rewriter, getLoc(),
+                                               *copyValueBuffer, *alloc)))
+      return failure();
+  }
   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
   return success();
 }

+bool AllocTensorOp::isMemoryWrite(OpResult opResult,
+                                  const AnalysisState &state) {
+  // AllocTensorOps do not write unless they have a `copy` value.
+  return static_cast<bool>(copy());
+}
+
+bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
+                                           const AnalysisState &state) {
+  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
+         "expected copy operand");
+  return true;
+}
+
+bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
+                                            const AnalysisState &state) {
+  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
+         "expected copy operand");
+  return false;
+}
+
+SmallVector<OpResult>
+AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
+                                   const AnalysisState &state) {
+  // This is a new allocation. It does not alias with any other buffer.
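+  // Even the `copy` operand does not alias the result: `bufferize` fills the
+  // new allocation with a memcpy rather than reusing the source buffer.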
+  return {};
+}
+
 LogicalResult AllocTensorOp::verify() {
-  if (getType().getNumDynamicDims() !=
-      static_cast<int64_t>(dynamicSizes().size()))
+  if (copy() && !dynamicSizes().empty())
+    return emitError("expected no dynamic sizes when copying a tensor");
+  if (!copy() && getType().getNumDynamicDims() !=
+                     static_cast<int64_t>(dynamicSizes().size()))
     return emitError("expected ")
            << getType().getNumDynamicDims() << " dynamic sizes";
   return success();
@@ -171,6 +213,8 @@

   LogicalResult matchAndRewrite(AllocTensorOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.copy())
+      return failure();
     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
     SmallVector<Value> newDynamicSizes;
     unsigned int dynValCounter = 0;
@@ -189,8 +233,9 @@
         newShape, op.getType().getElementType(), op.getType().getEncoding());
     if (newType == op.getType())
       return failure();
-    auto newOp =
-        rewriter.create<AllocTensorOp>(op.getLoc(), newType, newDynamicSizes);
+    auto newOp = rewriter.create<AllocTensorOp>(
+        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value(),
+        /*escape=*/op.escapeAttr());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
     return success();
   }
@@ -207,8 +252,8 @@
       return failure();
     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
       return failure();
-    rewriter.replaceOp(dimOp,
-                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
+    rewriter.replaceOp(
+        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
     return success();
   }
 };
@@ -224,7 +269,7 @@
   auto shapes = llvm::to_vector<4>(llvm::map_range(
       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
         if (isDynamicDim(dim))
-          return getDynamicSize(dim);
+          return getDynamicSize(builder, dim);
         return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                       getStaticSize(dim));
       }));
@@ -232,6 +277,59 @@
   return success();
 }

+ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
+  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
+      parser.parseRParen())
+    return failure();
+  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
+  OpAsmParser::UnresolvedOperand copyOperand;
+  if (copyKeyword.succeeded())
+    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
+        parser.parseRParen())
+      return failure();
+  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
+    return failure();
+
+  TensorType type;
+  if (parser.parseCustomTypeWithFallback(type))
+    return failure();
+  result.addTypes(type);
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
+    return failure();
+  if (copyKeyword.succeeded())
+    if (parser.resolveOperand(copyOperand, type, result.operands))
+      return failure();
+  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
+                      parser.getBuilder().getI32VectorAttr(
+                          {static_cast<int32_t>(dynamicSizesOperands.size()),
+                           static_cast<int32_t>(copyKeyword.succeeded())}));
+  return success();
+}
+
+void AllocTensorOp::print(OpAsmPrinter &p) {
+  p << "(" << dynamicSizes() << ")";
+  if (copy())
+    p << " copy(" << copy() << ")";
+  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
+      AllocTensorOp::getOperandSegmentSizeAttr()});
+  p << " : ";
+  auto type = result().getType();
+  if (auto validType = type.dyn_cast<::mlir::TensorType>())
+    p.printStrippedAttrOrType(validType);
+  else
+    p << type;
+}
+
+Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
+  assert(isDynamicDim(idx) && "expected dynamic dim");
+  if (copy())
+    return b.create<tensor::DimOp>(getLoc(), copy(), idx);
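+  // Without a `copy` operand, the dynamic size must have been passed
+  // explicitly; look it up among the dynamic size operands.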
+  return getOperand(getIndexOfDynamicSize(idx));
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp b/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InitTensorToAllocTensor.cpp
@@ -23,8 +23,8 @@

   LogicalResult matchAndRewrite(InitTensorOp op,
                                 PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(op, op.getType(),
-                                                              op.sizes());
+    rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
+        op, op.getType(), op.sizes(), /*copy=*/Value(), /*escape=*/BoolAttr());
     return success();
   }
 };

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -119,3 +119,22 @@
   %1 = arith.select %c, %0, %t : tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @alloc_tensor_with_copy(
+//  CHECK-SAME:     %[[t:.*]]: tensor<5xf32>)
+// TODO: Add a test case with dynamic dim size. This is not possible at the
+// moment because this would create a tensor op during bufferization. That is
+// currently forbidden.
+func.func @alloc_tensor_with_copy(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: memref.copy %[[m]], %[[alloc]]
+  %0 = bufferization.alloc_tensor() copy(%t) : tensor<5xf32>
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<5xf32>
+}
+

diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -224,7 +224,7 @@
   return %1 : memref<?x?x16x32xi8>
 }
 // CHECK: %[[M:.+]] = bufferization.to_memref %[[ARG0]] : memref<4x6x16x32xi8>
-// CHECK: %[[M1:.+]] = memref.cast %[[M]] 
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK: return %[[M1]] : memref<?x?x16x32xi8>

diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -22,3 +22,22 @@
   %tensor = bufferization.to_tensor %buf : memref<2xf32>
   return %tensor : tensor<2xf32>
 }
+
+// CHECK-LABEL: func @test_alloc_tensor_op
+func.func @test_alloc_tensor_op(%t: tensor<?x5xf32>, %sz: index)
+  -> tensor<?x5xf32>
+{
+  // CHECK: bufferization.alloc_tensor(%{{.*}}) : tensor<?x5xf32>
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) : tensor<?x5xf32>
+  %1 = bufferization.alloc_tensor() copy(%t) : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() : tensor<5x6xf32>
+  %2 = bufferization.alloc_tensor() : tensor<5x6xf32>
+  // CHECK: bufferization.alloc_tensor(%{{.*}}, %{{.*}}) : tensor<?x?xf32>
+  %3 = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<?x5xf32>
+  %4 = bufferization.alloc_tensor() copy(%t) {escape = true} : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = false} : tensor<?x5xf32>
+  %5 = bufferization.alloc_tensor() copy(%t) {escape = false} : tensor<?x5xf32>
+  return %1 : tensor<?x5xf32>
+}