diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -87,6 +87,12 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns); +/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its +/// source, if the producer is a `linalg` operation with all parallel iterator +/// types. +void populateFusePadTensorWithProducerLinalgOpPatterns( + RewritePatternSet &patterns); + /// Patterns to convert from one named op to another. These can be seen as /// canonicalizations of named ops into another named op. void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ Loops.cpp LinalgStrategyPasses.cpp NamedOpConversions.cpp + PadOpInterchange.cpp Promotion.cpp Tiling.cpp Transforms.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp @@ -0,0 +1,122 @@ +//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns that interchange a generic op -> pad_tensor +// pattern into extract_slice -> generic_op. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { + +/// A sequence of operations +/// +/// ```mlir +/// %0 = linalg. ... +/// %1 = linalg.pad_tensor %0 ... +/// ``` +/// +/// can be replaced with +/// +/// ```mlir +/// %0 = linalg.fill +/// %1 = tensor.extract_slice %0 ... +/// %2 = linalg. .... outs(..., %1, ....) .... +/// %3 = tensor.insert_slice %2 into %0 ... +/// ``` +/// +/// if the `linalg.generic` has all parallel iterator types. +struct FusePadTensorOp : OpRewritePattern<PadTensorOp> { + using OpRewritePattern<PadTensorOp>::OpRewritePattern; + LogicalResult matchAndRewrite(PadTensorOp padOp, + PatternRewriter &rewriter) const override { + // Only works on padding op that sets the padded value to a constant. + Value padValue = padOp.getConstantPaddingValue(); + if (!padValue) + return rewriter.notifyMatchFailure(padOp, "non constant padding"); + + // This pattern could work for any Linalg op. For now restrict it to generic + // ops. + Value source = padOp.source(); + auto linalgOp = source.getDefiningOp<GenericOp>(); + if (!linalgOp) { + return rewriter.notifyMatchFailure( + padOp, "expected source to be linalg.generic op"); + } + // All iterator types need to be parallel. + if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) { + return rewriter.notifyMatchFailure( + padOp, "only supported for ops with all parallel iterator types"); + } + ReifiedRankedShapedTypeDims resultShape; + if (failed(padOp.reifyResultShapes(rewriter, resultShape)) || + resultShape.size() != 1) { + return rewriter.notifyMatchFailure( + padOp, "failed to get shape of pad op result"); + } + + Location loc = padOp.getLoc(); + + // Create the tensor of same size as output of the pad op. 
+ RankedTensorType padResultType = padOp.getResultType(); + auto resultSizes = getAsOpFoldResult(resultShape[0]); + auto initTensor = rewriter.create<InitTensorOp>( + loc, resultSizes, padResultType.getElementType()); + + // Fill the tensor with the pad value. + // TODO: There is an option to fill only the boundaries. For now just + // filling the whole tensor. + auto fillTensor = + rewriter.create<FillOp>(loc, padValue, initTensor.getResult()); + + // Construct a slice of the fill result that is to be replaced with the + // result of the generic op. The low pad values are the offsets, the size of + // the source is the size of the slice. + // TODO: This insert/extract could be potentially made a utility method. + unsigned resultNumber = source.cast<OpResult>().getResultNumber(); + SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad(); + SmallVector<OpFoldResult> sizes; + sizes.reserve(offsets.size()); + for (auto shape : llvm::enumerate( + source.getType().cast<RankedTensorType>().getShape())) { + if (ShapedType::isDynamic(shape.value())) { + sizes.push_back( + rewriter.create<tensor::DimOp>(loc, source, shape.index()) + .getResult()); + } else { + sizes.push_back(rewriter.getIndexAttr(shape.value())); + } + } + SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1)); + auto slice = rewriter.create<tensor::ExtractSliceOp>( + loc, fillTensor.getResult(0), offsets, sizes, strides); + + // Clone the generic op. + auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation())); + clonedOp.setOutputOperand(resultNumber, slice.getResult()); + + // Insert it back into the result of the fill. 
+ rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( + padOp, clonedOp.getResult(resultNumber), fillTensor.getResult(0), + offsets, sizes, strides); + return success(); + } +}; +} // namespace + +void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns( + RewritePatternSet &patterns) { + patterns.add<FusePadTensorOp>(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s + +func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) { + ^bb0(%arg6 : f32, %arg7 : f32): + %1 = arith.mulf %arg6, %arg6 : f32 + linalg.yield %1 : f32 + } -> tensor<?x?xf32> + %1 = linalg.pad_tensor %0 low [%arg1, %arg2] high [%arg3, %arg4] { + ^bb0(%arg6: index, %arg7 : index): + linalg.yield %arg5 : f32 + } : tensor<?x?xf32> to tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)> +// CHECK: func @dynamic_pad_fusion +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic +// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], 
%[[C0]] +// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]] +// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]] +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]] +// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]]) +// CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]] +// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] +// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1] +// CHECK: %[[SOURCE:.+]] = linalg.generic +// CHECK-SAME: outs(%[[SLICE]] : tensor<?x?xf32>) +// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]] +// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1] +// CHECK: return %[[RESULT]] + +// ----- + +func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index, + %arg3 : f32) -> tensor<49x?xf32> { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arg0, %c0 : tensor<?x42xf32> + %init = linalg.init_tensor [42, %d0] : tensor<42x?xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<?x42xf32>) outs(%init : tensor<42x?xf32>) { + ^bb0(%arg4 : f32, %arg5 : f32): + %1 = arith.mulf %arg4, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor<42x?xf32> + %1 = linalg.pad_tensor %0 low [3, %arg1] high [4, %arg2] { + ^bb0(%arg4: index, %arg5 : index): + linalg.yield %arg3 : f32 + } : tensor<42x?xf32> to tensor<49x?xf32> + return %1 : tensor<49x?xf32> +} +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)> +// CHECK: func @mixed_pad_fusion +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x42xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: f32 
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic +// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]] +// CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]] +// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]]) +// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] +// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1] +// CHECK: %[[SOURCE:.+]] = linalg.generic +// CHECK-SAME: outs(%[[SLICE]] : tensor<42x?xf32>) +// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]] +// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -8,6 +8,7 @@ TestLinalgFusionTransforms.cpp TestLinalgHoisting.cpp TestLinalgTransforms.cpp + TestPadFusion.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp @@ -0,0 +1,48 @@ +//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing fusion of pad ops with its producer +// Linalg op. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { + +namespace { +struct TestPadFusionPass : public PassWrapper<TestPadFusionPass, FunctionPass> { + + void getDependentDialects(DialectRegistry &registry) const override { + registry + .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); + } + + StringRef getArgument() const final { return "test-linalg-pad-fusion"; } + StringRef getDescription() const final { return "Test PadOp fusion"; } + + void runOnFunction() override { + MLIRContext *context = &getContext(); + FuncOp funcOp = getFunction(); + RewritePatternSet patterns(context); + linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +namespace test { +void registerTestPadFusion() { PassRegistration<TestPadFusionPass>(); } +} // namespace test + +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -103,6 +103,7 @@ void registerTestNumberOfBlockExecutionsPass(); void registerTestNumberOfOperationExecutionsPass(); void registerTestOpaqueLoc(); +void registerTestPadFusion(); void registerTestPDLByteCodePass(); void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); @@ -195,6 +196,7 @@ mlir::test::registerTestNumberOfBlockExecutionsPass(); mlir::test::registerTestNumberOfOperationExecutionsPass(); mlir::test::registerTestOpaqueLoc(); + mlir::test::registerTestPadFusion(); mlir::test::registerTestPDLByteCodePass(); mlir::test::registerTestRecursiveTypesPass(); mlir::test::registerTestSCFUtilsPass();