diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -222,6 +222,9 @@
       return getResult().getType().cast<RankedTensorType>();
     }
 
+    // Infer the dynamic shape of the result tensor along each dim
+    SmallVector<Value> getResultTypeShapes(OpBuilder &b);
+
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -936,8 +936,7 @@
                   builder);
 }
 
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+SmallVector<Value> PadTensorOp::getResultTypeShapes(OpBuilder &b) {
   Location loc = getLoc();
   auto lowPad = getMixedLowPad();
   auto highPad = getMixedHighPad();
@@ -963,7 +962,12 @@
     shapes.push_back(applyMapToValues(
         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
   }
-  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return shapes;
+}
+
+LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  reifiedReturnShapes.emplace_back(getResultTypeShapes(b));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -116,6 +116,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseSet.h"
 
@@ -2581,8 +2582,87 @@
 };
 } // end namespace
 
+/// Returns the static/dynamic mixed sizes of the memref.
+static SmallVector<OpFoldResult> getMixedSizes(OpBuilder &b, Location loc,
+                                               Value memref) {
+  auto inputType = memref.getType().cast<ShapedType>();
+  auto inputShape = inputType.getShape();
+  SmallVector<OpFoldResult> sizeMixedValues;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (inputShape[i] == ShapedType::kDynamicSize) {
+      Value dim = b.create<memref::DimOp>(loc, memref, i);
+      sizeMixedValues.push_back(dim);
+    } else {
+      sizeMixedValues.push_back(b.getI64IntegerAttr(inputShape[i]));
+    }
+  }
+  return sizeMixedValues;
+}
+
+/// Conversion pattern that bufferizes `linalg.pad_tensor` operation.
+class PadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
+public:
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp op,
+                                PatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    Value sourceTensor = op.source();
+    auto sourceTensorType = sourceTensor.getType().cast<RankedTensorType>();
+
+    // Allocate the destination buffer
+    SmallVector<Value> resultShape = op.getResultTypeShapes(rewriter);
+    Value resultTensor = rewriter.create<InitTensorOp>(
+        loc, resultShape, sourceTensorType.getElementType());
+
+    // Get padding value and fill the destination buffer.
+    auto yieldOps = op.region().getOps<YieldOp>();
+    if (std::distance(yieldOps.begin(), yieldOps.end()) != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "linalg.pad_tensor with more than one "
+                                         "padding value is not supported");
+    }
+    Value paddingValue = (*yieldOps.begin()).values()[0];
+    auto constOp = paddingValue.getDefiningOp<ConstantOp>();
+    if (!constOp) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "linalg.pad_tensor with non-constant padding value is not supported");
+    }
+    if (constOp.getValue().isa<DenseElementsAttr>()) {
+      return rewriter.notifyMatchFailure(
+          op, "linalg.pad_tensor with non-scalar constant padding value is not "
+              "supported");
+    }
+    resultTensor =
+        rewriter.create<FillOp>(loc, paddingValue, resultTensor)
+            ->getResult(0);
+
+    // Get the interior region.
+    SmallVector<OpFoldResult> sizes =
+        getMixedSizes(rewriter, loc, sourceTensor);
+    SmallVector<OpFoldResult> strides(sourceTensorType.getRank(),
+                                      rewriter.getI64IntegerAttr(1));
+
+    // Copy input into the interior region
+    resultTensor = rewriter.create<tensor::InsertSliceOp>(
+        loc, sourceTensor, resultTensor, op.getMixedLowPad(), sizes, strides);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getResultType(),
+                                                resultTensor);
+    return success();
+  }
+};
+
+static void applyEnablingTransformations(ModuleOp moduleOp) {
+  RewritePatternSet patterns(moduleOp.getContext());
+  patterns.add<PadTensorOpPattern>(moduleOp.getContext());
+  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
+  applyEnablingTransformations(moduleOp);
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;