diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -103,10 +103,9 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns); -/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its -/// source, if the producer is a `linalg` operation with all parallel iterator -/// types. -void populateFusePadTensorWithProducerLinalgOpPatterns( +/// Pattern to fuse a `tensor.pad` operation with the producer of its source, +/// if the producer is a `linalg` operation with all parallel iterator types. +void populateFuseTensorPadWithProducerLinalgOpPatterns( RewritePatternSet &patterns); /// Patterns to convert from one named op to another. These can be seen as diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp @@ -1,4 +1,4 @@ -//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===// +//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,9 @@ // //===----------------------------------------------------------------------===// // -// This file implements patterns that intechanges a generic op -> pad_tensor -// pattern into extract_slice -> generic_op. +// This file implements patterns that interchange a linalg.generic -> tensor.pad +// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice +// op chain. 
// //===----------------------------------------------------------------------===// @@ -17,7 +18,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; -using namespace mlir::linalg; namespace { @@ -25,7 +25,7 @@ /// /// ```mlir /// %0 = linalg. ... -/// %1 = linalg.pad_tensor %0 ... +/// %1 = tensor.pad %0 ... /// ``` /// /// can be replaced with @@ -40,6 +40,7 @@ /// if the `linalg.generic` has all parallel iterator types. struct FusePadOp : OpRewritePattern<tensor::PadOp> { using OpRewritePattern<tensor::PadOp>::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override { // Only works on padding op that sets the padded value to a constant. @@ -50,7 +51,7 @@ // This pattern could work for any Linalg op. For now restrict it to generic // ops. Value source = padOp.source(); - auto linalgOp = source.getDefiningOp<GenericOp>(); + auto linalgOp = source.getDefiningOp<linalg::GenericOp>(); if (!linalgOp) { return rewriter.notifyMatchFailure( padOp, "expected source to be linalg.generic op"); @@ -75,14 +76,14 @@ // Create the tensor of same size as output of the pad op. RankedTensorType padResultType = padOp.getResultType(); auto resultSizes = getAsOpFoldResult(resultShape[0]); - auto initTensor = rewriter.create<InitTensorOp>( + auto initTensor = rewriter.create<linalg::InitTensorOp>( loc, resultSizes, padResultType.getElementType()); // Fill the tensor with the pad value. // TODO: There is an option to fill only the boundaries. For now just // filling the whole tensor. auto fillTensor = - rewriter.create<FillOp>(loc, padValue, initTensor.getResult()); + rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult()); // Construct a slice of the fill result that is to be replaced with the // result of the generic op. The low pad values are the offsets, the size of @@ -107,7 +108,8 @@ loc, fillTensor.getResult(0), offsets, sizes, strides); // Clone the generic op. 
- auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation())); + auto clonedOp = + cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation())); clonedOp.setOutputOperand(resultNumber, slice.getResult()); // Insert it back into the result of the fill. @@ -119,7 +121,7 @@ }; } // namespace -void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns( +void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns( RewritePatternSet &patterns) { patterns.add<FusePadOp>(patterns.getContext()); } diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp @@ -34,7 +34,7 @@ MLIRContext *context = &getContext(); FuncOp funcOp = getOperation(); RewritePatternSet patterns(context); - linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns); + linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure();