diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -32,8 +32,7 @@
 }

 static LogicalResult
-allocateBuffersForResults(Location loc, LinalgOp linalgOp,
-                          linalg::GenericOpAdaptor &adaptor,
+allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
   // Lazily compute loopRanges.
   SmallVector<Range, 4> loopRanges;
@@ -52,7 +51,7 @@
     }
     auto tensorShape = tensorType.getShape();
     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
-    Value resultTensor = adaptor.outputs()[resultIndex];
+    Value resultTensor = outputs[resultIndex];

     // Clone output buffers whose value is actually used.
     if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
@@ -138,8 +137,7 @@

 namespace {

-/// Generic conversion pattern that matches any LinalgOp. This avoids template
-/// instantiating one pattern for each LinalgOp.
+/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
 public:
   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
@@ -155,6 +153,26 @@
   }
 };

+/// Conversion pattern that bufferizes `linalg.fill` operation.
+class BufferizeFillOp : public OpConversionPattern<FillOp> {
+public:
+  using OpConversionPattern<FillOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
+    if (!op.output().getType().isa<TensorType>())
+      return rewriter.notifyMatchFailure(op,
+                                         "operand must be of a tensor type");
+
+    rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
+    rewriter.replaceOp(op, adaptor.output());
+
+    return success();
+  }
+};
+
 /// Generic conversion pattern that matches any LinalgOp. This avoids template
 /// instantiating one pattern for each LinalgOp.
 class BufferizeAnyLinalgOp : public ConversionPattern {
@@ -178,7 +196,7 @@
     Location loc = linalgOp.getLoc();
     SmallVector<Value, 2> newOutputBuffers;

-    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
+    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(),
                                          newOutputBuffers, rewriter))) {
       linalgOp.emitOpError()
           << "Failed to allocate buffers for tensor results.";
@@ -325,6 +343,7 @@
   // TODO: Drop this once tensor constants work in standard.
   // clang-format off
   patterns.insert<
+      BufferizeFillOp,
       BufferizeInitTensorOp,
       SubTensorOpConverter,
       SubTensorInsertOpConverter
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -265,3 +265,16 @@
   return %t0, %t1: tensor<?x?xf32>, tensor<?x?xf32>
 }

+// -----
+
+// CHECK-LABEL: func @bufferize_fill(
+// CHECK-SAME:    %[[IN:.*]]: tensor<?xf32>
+func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = constant 0.0 : f32
+  // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[IN]] : memref<?xf32>
+  // CHECK: linalg.fill(%[[MEMREF]], %cst) : memref<?xf32>, f32
+  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]] : memref<?xf32>
+  // CHECK: return %[[TENSOR]]
+  %0 = linalg.fill(%arg0, %c0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}