diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -222,6 +222,9 @@
       return getResult().getType().cast<RankedTensorType>();
     }
 
+    // Infer the dynamic shape of the result tensor along each dim.
+    SmallVector<Value> getResultTypeShapes(OpBuilder &b);
+
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -964,8 +964,7 @@
                               builder);
 }
 
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+SmallVector<Value> PadTensorOp::getResultTypeShapes(OpBuilder &b) {
   Location loc = getLoc();
   auto lowPad = getMixedLowPad();
   auto highPad = getMixedHighPad();
@@ -991,7 +990,12 @@
     shapes.push_back(applyMapToValues(
         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
   }
-  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return shapes;
+}
+
+LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  reifiedReturnShapes.emplace_back(getResultTypeShapes(b));
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -194,6 +194,86 @@
   }
 };
 
+/// Returns the static/dynamic mixed sizes of the memref.
+static SmallVector<OpFoldResult> getMixedSizes(OpBuilder &b, Location loc,
+                                               Value memref) {
+  auto inputType = memref.getType().cast<MemRefType>();
+  auto inputShape = inputType.getShape();
+  SmallVector<OpFoldResult> sizeMixedValues;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (inputShape[i] == ShapedType::kDynamicSize) {
+      Value dim = b.create<memref::DimOp>(loc, memref, i);
+      sizeMixedValues.push_back(dim);
+    } else {
+      sizeMixedValues.push_back(b.getI64IntegerAttr(inputShape[i]));
+    }
+  }
+  return sizeMixedValues;
+}
+
+/// Conversion pattern that bufferizes the `linalg.pad_tensor` operation.
+class BufferizePadTensorOp : public OpConversionPattern<PadTensorOp> {
+public:
+  using OpConversionPattern<PadTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(PadTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    linalg::PadTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
+    Value sourceMemRef = adaptor.source();
+    assert(sourceMemRef.getType().isa<MemRefType>());
+
+    auto sourceType = sourceMemRef.getType().cast<MemRefType>();
+    // Allocate the destination buffer.
+    SmallVector<Value> shapes = op.getResultTypeShapes(rewriter);
+    SmallVector<int64_t> dynShapes(sourceType.getRank(), -1);
+    auto memrefType = MemRefType::get(dynShapes, sourceType.getElementType());
+    Value destMemRef =
+        rewriter.create<memref::AllocOp>(loc, memrefType, shapes);
+
+    // Get the padding value and fill the destination buffer.
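+    // Only a region that consists of a single linalg.yield of a scalar
+    // constant is supported; anything else makes the pattern bail out below.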
+    auto yieldOps = op.region().getOps<YieldOp>();
+    if (std::distance(yieldOps.begin(), yieldOps.end()) != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "linalg.pad_tensor with more than one "
+                                         "padding value is not supported");
+    }
+    Value paddingValue = (*yieldOps.begin()).values()[0];
+    auto constOp = paddingValue.getDefiningOp<ConstantOp>();
+    if (!constOp) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "linalg.pad_tensor with non-constant padding value is not supported");
+    }
+    if (constOp.getValue().isa<DenseElementsAttr>()) {
+      return rewriter.notifyMatchFailure(
+          op, "linalg.pad_tensor with non-scalar constant padding value is not "
+              "supported");
+    }
+    rewriter.create<FillOp>(loc, paddingValue, destMemRef);
+
+    // Get the interior region.
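+    // The interior is the subview at offsets equal to the low padding, with
+    // the sizes of the source; everything outside of it keeps the padding
+    // value written by the fill above.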
+    SmallVector<OpFoldResult> sizes =
+        getMixedSizes(rewriter, loc, sourceMemRef);
+    SmallVector<OpFoldResult> strides(sourceType.getRank(),
+                                      rewriter.getI64IntegerAttr(1));
+    auto resultSubView = rewriter.create<memref::SubViewOp>(
+        loc, destMemRef, op.getMixedLowPad(), sizes, strides);
+    // Copy the input into the interior region.
+    rewriter.create<CopyOp>(loc, sourceMemRef, resultSubView);
+    auto newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, newResultType,
+                                                destMemRef);
+    return success();
+  }
+};
+
 /// Generic conversion pattern that matches any LinalgOp. This avoids template
 /// instantiating one pattern for each LinalgOp.
 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
@@ -359,7 +439,8 @@
       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
       ExtractSliceOpConverter,
-      InsertSliceOpConverter
+      InsertSliceOpConverter,
+      BufferizePadTensorOp
     >(typeConverter, patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -265,3 +265,30 @@
 // CHECK-SAME:      : memref<4x5xf32> into memref<20xf32>
 // CHECK:           %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
 // CHECK:           return %[[TENSOR]]
+
+// CHECK-LABEL:   func @bufferize_pad_tensor(
+// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x?x2x?xf32>,
+// CHECK-SAME:                               %[[VAL_1:.*]]: index) -> tensor<4x?x?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = constant 3 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_5:.*]] = memref.buffer_cast %[[VAL_0]] : memref<4x?x2x?xf32>
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[VAL_7:.*]] = affine.apply #map0(){{\[}}%[[VAL_1]]]
+// CHECK:           %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_2]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[VAL_9:.*]] = affine.apply #map1(){{\[}}%[[VAL_1]], %[[VAL_8]]]
+// CHECK:           %[[VAL_10:.*]] = memref.alloc(%[[VAL_6]], %[[VAL_7]], %[[VAL_9]]) : memref<4x?x?x?xf32>
+// CHECK:           linalg.fill(%[[VAL_4]], %[[VAL_10]]) : f32, memref<4x?x?x?xf32>
+// CHECK:           %[[VAL_11:.*]] = memref.subview %[[VAL_10]][0, 0, %[[VAL_1]], 0] [4, %[[VAL_6]], 2, %[[VAL_8]]] [1, 1, 1, 1] : memref<4x?x?x?xf32> to memref<4x?x2x?xf32, #map2>
+// CHECK:           linalg.copy(%[[VAL_5]], %[[VAL_11]]) : memref<4x?x2x?xf32>, memref<4x?x2x?xf32, #map2>
+// CHECK:           %[[VAL_12:.*]] = memref.tensor_load %[[VAL_10]] : memref<4x?x?x?xf32>
+// CHECK:           return %[[VAL_12]] : tensor<4x?x?x?xf32>
+func @bufferize_pad_tensor(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %out = linalg.pad_tensor %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
+  return %out : tensor<4x?x?x?xf32>
+}
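
For quick orientation, here is a hand-written sketch (not part of the patch and
not generated by the test suite) of the buffer program the new pattern emits
for a fully static pad; the names %src, %in, %buf, and %sub are illustrative
only:

  %0 = linalg.pad_tensor %src low[1, 1] high[1, 1] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<2x2xf32> to tensor<4x4xf32>

becomes, up to the final cast back to the converted result type, roughly:

  %in = memref.buffer_cast %src : memref<2x2xf32>
  %buf = memref.alloc(%c4, %c4) : memref<?x?xf32>
  linalg.fill(%cst, %buf) : f32, memref<?x?xf32>
  %sub = memref.subview %buf[1, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #map>
  linalg.copy(%in, %sub) : memref<2x2xf32>, memref<2x2xf32, #map>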