diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -222,6 +222,9 @@
       return getResult().getType().cast<RankedTensorType>();
     }
 
+    // Infer the dynamic shape of the result tensor along each dim.
+    SmallVector<Value> getResultTypeShapes(OpBuilder &b);
+
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -936,8 +936,7 @@
                                 builder);
 }
 
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+SmallVector<Value> PadTensorOp::getResultTypeShapes(OpBuilder &b) {
   Location loc = getLoc();
   auto lowPad = getMixedLowPad();
   auto highPad = getMixedHighPad();
@@ -963,7 +962,12 @@
     shapes.push_back(applyMapToValues(
         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
   }
-  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return shapes;
+}
+
+LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  reifiedReturnShapes.emplace_back(getResultTypeShapes(b));
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -116,6 +116,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 #include "llvm/ADT/DenseSet.h"
@@ -2451,8 +2452,88 @@
 };
 } // end namespace
 
+/// Returns the static/dynamic mixed sizes of the tensor.
+static SmallVector<OpFoldResult> getMixedSizes(OpBuilder &b, Location loc,
+                                               Value tensor) {
+  auto inputType = tensor.getType().cast<RankedTensorType>();
+  auto inputShape = inputType.getShape();
+  SmallVector<OpFoldResult> sizeMixedValues;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (inputShape[i] == ShapedType::kDynamicSize) {
+      Value dim = b.create<tensor::DimOp>(loc, tensor, i);
+      sizeMixedValues.push_back(dim);
+    } else {
+      sizeMixedValues.push_back(b.getI64IntegerAttr(inputShape[i]));
+    }
+  }
+  return sizeMixedValues;
+}
+
+/// Converts a `linalg.pad_tensor` operation to other ops that will be
+/// bufferized by the general bufferization pass.
+class PadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
+public:
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value sourceTensor = op.source();
+    auto sourceTensorType = sourceTensor.getType().cast<RankedTensorType>();
+
+    // Allocate the destination buffer.
+    SmallVector<Value> resultShape = op.getResultTypeShapes(rewriter);
+    Value resultTensor = rewriter.create<InitTensorOp>(
+        loc, resultShape, sourceTensorType.getElementType());
+
+    // Get the padding value and fill the destination buffer.
+    auto yieldOps = op.region().getOps<YieldOp>();
+    if (!llvm::hasSingleElement(yieldOps)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "linalg.pad_tensor with more than one "
+                                         "padding value is not supported");
+    }
+    Value paddingValue = (*yieldOps.begin()).values()[0];
+    auto constOp = paddingValue.getDefiningOp<ConstantOp>();
+    if (!constOp) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "linalg.pad_tensor with non-constant padding value is not supported");
+    }
+    if (constOp.getValue().isa<DenseElementsAttr>()) {
+      return rewriter.notifyMatchFailure(
+          op, "linalg.pad_tensor with non-scalar constant padding value is not "
+              "supported");
+    }
+    resultTensor =
+        rewriter.create<FillOp>(loc, paddingValue, resultTensor)->getResult(0);
+
+    // Get the interior region.
+    SmallVector<OpFoldResult> sizes =
+        getMixedSizes(rewriter, loc, sourceTensor);
+    SmallVector<OpFoldResult> strides(sourceTensorType.getRank(),
+                                      rewriter.getI64IntegerAttr(1));
+
+    // Copy the input into the interior region.
+    resultTensor = rewriter.create<tensor::InsertSliceOp>(
+        loc, sourceTensor, resultTensor, op.getMixedLowPad(), sizes, strides);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getResultType(),
+                                                resultTensor);
+    return success();
+  }
+};
+
+static void applyEnablingTransformations(ModuleOp moduleOp) {
+  RewritePatternSet patterns(moduleOp.getContext());
+  patterns.add<PadTensorOpPattern>(moduleOp.getContext());
+  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
+  applyEnablingTransformations(moduleOp);
 
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -canonicalize -cse -linalg-comprehensive-module-bufferize |\
+// RUN: mlir-opt -convert-vector-to-scf -lower-affine -convert-linalg-to-loops |\
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+func @main() {
+  %const = constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]]> : tensor<1x2x3xf32>
+  %dynamic = tensor.cast %const : tensor<1x2x3xf32> to tensor<1x?x3xf32>
+  %offset = constant 2 : index
+  %cst = constant 2.3 : f32
+  %c0 = constant 0 : index
+  %out = linalg.pad_tensor %dynamic low[%c0, %offset, %c0] high[%c0, %c0, %offset] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x?x3xf32> to tensor<1x?x?xf32>
+  %unranked = tensor.cast %out : tensor<1x?x?xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 3 offset = 0 sizes = [1, 4, 5] strides = [20, 5, 1] data =
+  // CHECK-NEXT{LITERAL}: [[[2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [1, 2, 3, 2.3, 2.3],
+  // CHECK-NEXT: [2, 3, 4, 2.3, 2.3]]]
+
+  return
+}
+
+func private @print_memref_f32(%ptr : tensor<*xf32>)
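---
Note for reviewers: a minimal hand-written sketch of the IR that PadTensorOpPattern should produce for the pad in the test above, before bufferization runs. The SSA names (%init, %fill, %ins, %s0-%s2, %d1) are illustrative placeholders, not pass output; %s0-%s2 stand for the reified result sizes from getResultTypeShapes and %d1 for the tensor.dim of the dynamic source dimension.

  // %out = linalg.pad_tensor %dynamic low[%c0, %offset, %c0] high[%c0, %c0, %offset] {...}
  // becomes, schematically:
  %init = linalg.init_tensor [%s0, %s1, %s2] : tensor<?x?x?xf32>
  %fill = linalg.fill(%cst, %init) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
  %ins = tensor.insert_slice %dynamic into %fill[%c0, %offset, %c0] [1, %d1, 3] [1, 1, 1]
      : tensor<1x?x3xf32> into tensor<?x?x?xf32>
  %out = tensor.cast %ins : tensor<?x?x?xf32> to tensor<1x?x?xf32>

All four result ops already have bufferization support, which is what makes this an enabling transformation for the comprehensive bufferization pass.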