diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -222,6 +222,9 @@
       return getResult().getType().cast<RankedTensorType>();
     }
 
+    // Infer the dynamic shape of the result tensor along each dim.
+    SmallVector<Value> getResultTypeShapes(OpBuilder &b);
+
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -964,8 +964,7 @@
                         builder);
 }
 
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+SmallVector<Value> PadTensorOp::getResultTypeShapes(OpBuilder &b) {
   Location loc = getLoc();
   auto lowPad = getMixedLowPad();
   auto highPad = getMixedHighPad();
@@ -991,7 +990,12 @@
     shapes.push_back(applyMapToValues(
         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
   }
-  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return shapes;
+}
+
+LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  reifiedReturnShapes.emplace_back(getResultTypeShapes(b));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -194,6 +194,80 @@
   }
 };
 
+/// Returns the static/dynamic mixed sizes of the memref.
+static SmallVector<OpFoldResult> getMixedSizes(OpBuilder &b, Location loc,
+                                               Value memref) {
+  auto inputType = memref.getType().cast<MemRefType>();
+  auto inputShape = inputType.getShape();
+  SmallVector<OpFoldResult> sizeMixedValues;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (inputShape[i] == ShapedType::kDynamicSize) {
+      Value dim = b.create<memref::DimOp>(loc, memref, i);
+      sizeMixedValues.push_back(dim);
+    } else {
+      sizeMixedValues.push_back(b.getI64IntegerAttr(inputShape[i]));
+    }
+  }
+  return sizeMixedValues;
+}
+
+/// Conversion pattern that bufferizes the `linalg.pad_tensor` operation.
+class BufferizePadTensorOp : public OpConversionPattern<PadTensorOp> {
+public:
+  using OpConversionPattern<PadTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(PadTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    PadTensorOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    Value sourceMemRef = adaptor.source();
+    assert(sourceMemRef.getType().isa<MemRefType>());
+
+    auto sourceType = sourceMemRef.getType().cast<MemRefType>();
+    // Allocate the destination buffer, sized with the reified result shape.
+    SmallVector<Value> shapes = op.getResultTypeShapes(rewriter);
+    SmallVector<int64_t> dynShapes(sourceType.getRank(),
+                                   ShapedType::kDynamicSize);
+    auto memrefType = MemRefType::get(dynShapes, sourceType.getElementType());
+    Value destMemRef =
+        rewriter.create<memref::AllocOp>(loc, memrefType, shapes);
+
+    // Get padding value and fill the destination buffer.
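+    // Only a single constant scalar padding value is supported, since it
+    // becomes the scalar operand of the linalg.fill below.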
+    auto yieldOps = op.region().getOps<YieldOp>();
+    if (!llvm::hasSingleElement(yieldOps)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "linalg.pad_tensor with more than one "
+                                         "padding value is not supported");
+    }
+    Value paddingValue = (*yieldOps.begin()).values()[0];
+    auto constOp = paddingValue.getDefiningOp<ConstantOp>();
+    if (!constOp) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "linalg.pad_tensor with non-constant padding value is not supported");
+    }
+    if (constOp.getValue().isa<DenseElementsAttr>()) {
+      return rewriter.notifyMatchFailure(
+          op, "linalg.pad_tensor with non-scalar constant padding value is not "
+              "supported");
+    }
+    rewriter.create<linalg::FillOp>(loc, paddingValue, destMemRef);
+
+    // Get the interior (non-padded) region of the destination buffer.
+    SmallVector<OpFoldResult> sizes =
+        getMixedSizes(rewriter, loc, sourceMemRef);
+    SmallVector<OpFoldResult> strides(sourceType.getRank(),
+                                      rewriter.getI64IntegerAttr(1));
+    auto resultSubView = rewriter.create<memref::SubViewOp>(
+        loc, destMemRef, op.getMixedLowPad(), sizes, strides);
+    // Copy the input into the interior region.
+    rewriter.create<linalg::CopyOp>(loc, sourceMemRef, resultSubView);
+    auto newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, newResultType, destMemRef);
+    return success();
+  }
+};
+
 /// Generic conversion pattern that matches any LinalgOp. This avoids template
 /// instantiating one pattern for each LinalgOp.
 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
@@ -359,7 +433,8 @@
       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
       ExtractSliceOpConverter,
-      InsertSliceOpConverter
+      InsertSliceOpConverter,
+      BufferizePadTensorOp
     >(typeConverter, patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -265,3 +265,30 @@
 // CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
 // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
 // CHECK: return %[[TENSOR]]
+
+// CHECK-LABEL: func @bufferize_pad_tensor(
+// CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>,
+// CHECK-SAME: %[[PAD_DYNAMIC:.*]]: index) -> tensor<4x?x?x?xf32> {
+// CHECK: %[[C3:.*]] = constant 3 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C0_FLOAT:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32>
+// CHECK: %[[DIM1:.*]] = memref.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
+// CHECK: %[[OUT_DIM2:.*]] = affine.apply #map0(){{\[}}%[[PAD_DYNAMIC]]]
+// CHECK: %[[DIM3:.*]] = memref.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
+// CHECK: %[[OUT_DIM3:.*]] = affine.apply #map1(){{\[}}%[[PAD_DYNAMIC]], %[[DIM3]]]
+// CHECK: %[[OUT_MEMREF:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32>
+// CHECK: linalg.fill(%[[C0_FLOAT]], %[[OUT_MEMREF]]) : f32, memref<4x?x?x?xf32>
+// CHECK: %[[OUT_INTERIOR:.*]] = memref.subview %[[OUT_MEMREF]][0, 0, %[[PAD_DYNAMIC]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : memref<4x?x?x?xf32> to memref<4x?x2x?xf32, #map2>
+// CHECK: linalg.copy(%[[IN_MEMREF]], %[[OUT_INTERIOR]]) : memref<4x?x2x?xf32>, memref<4x?x2x?xf32, #map2>
+// CHECK: %[[OUT:.*]] = memref.tensor_load %[[OUT_MEMREF]] : memref<4x?x?x?xf32>
+// CHECK: return %[[OUT]] : tensor<4x?x?x?xf32>
+func @bufferize_pad_tensor(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %out = linalg.pad_tensor %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
+  return %out : tensor<4x?x?x?xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN:   -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN:   -finalizing-bufferize \
+// RUN:   -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %const = constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]]> : tensor<1x2x3xf32>
+  %dynamic = tensor.cast %const : tensor<1x2x3xf32> to tensor<1x?x3xf32>
+  %offset = constant 2 : index
+  %cst = constant 2.3 : f32
+  %c0 = constant 0 : index
+  %out = linalg.pad_tensor %dynamic low[%c0, %offset, %c0] high[%c0, %c0, %offset] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x?x3xf32> to tensor<1x?x?xf32>
+  %unranked = tensor.cast %out : tensor<1x?x?xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
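+  // The result is 1x4x5: two rows of padding above the input (low pad on
+  // dim 1) and two columns of padding to its right (high pad on dim 2).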
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 3 offset = 0 sizes = [1, 4, 5] strides = [20, 5, 1] data =
+  // CHECK-NEXT{LITERAL}: [[[2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [1, 2, 3, 2.3, 2.3],
+  // CHECK-NEXT: [2, 3, 4, 2.3, 2.3]]]
+
+  return
+}
+
+func private @print_memref_f32(%ptr : tensor<*xf32>)