diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1398,6 +1398,14 @@
   ControlFn controlFn;
 };
 
+struct BubbleUpExtractSliceOpPattern
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Helper classes for type list expansion.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -0,0 +1,108 @@
+//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns that transform linalg.<op> +
+// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
+// the computation for the linalg op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Apply a projected permutation on a list of OpFoldResult. For example,
+/// when the mixed offsets of a tensor.extract_slice are [%o0, %o1] and the
+/// indexing map is <(d0, d1) -> (d1)>, the result will be [%o1].
+static SmallVector<OpFoldResult>
+applyProjectedPermutationMap(AffineMap map,
+                             ArrayRef<OpFoldResult> foldResults) {
+  assert(map.isProjectedPermutation() &&
+         "map must be a projected permutation.");
+
+  SmallVector<OpFoldResult> res;
+  res.reserve(map.getNumResults());
+
+  for (AffineExpr expr : map.getResults()) {
+    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+    res.push_back(foldResults[pos]);
+  }
+
+  return res;
+}
+
+/// Bubble up extract_slice above Linalg operation.
+///
+/// A sequence of operations
+///
+/// ```mlir
+/// %0 = linalg.<op> ... arg0, arg1, ...
+/// %1 = tensor.extract_slice %0 ...
+/// ```
+///
+/// can be replaced with
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+/// %1 = tensor.extract_slice %arg1
+/// %2 = linalg.<op> ... %0, %1, ...
+/// ```
+///
+/// This results in reduced computation for the linalg operation.
+///
+LogicalResult BubbleUpExtractSliceOpPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
+  Value source = sliceOp.source();
+  auto linalgOp = source.getDefiningOp<LinalgOp>();
+  if (!linalgOp) {
+    return rewriter.notifyMatchFailure(sliceOp,
+                                       "expected source to be linalg op");
+  }
+
+  if (!linalgOp->hasOneUse()) {
+    return rewriter.notifyMatchFailure(sliceOp,
+                                       "expected single use of linalg op");
+  }
+
+  if (!linalgOp.hasTensorSemantics()) {
+    return rewriter.notifyMatchFailure(sliceOp, "expected tensor of linalg op");
+  }
+
+  if (!llvm::all_of(linalgOp.indexing_maps().getValue(), [](Attribute attr) {
+        return attr.cast<AffineMapAttr>().getValue().isProjectedPermutation();
+      })) {
+    return rewriter.notifyMatchFailure(sliceOp,
+                                       "expected projected permutation");
+  }
+
+  // Bubble up the extract slice for each operand.
+  auto sliceOffsets = sliceOp.getMixedOffsets();
+  auto sliceSizes = sliceOp.getMixedSizes();
+  auto sliceStrides = sliceOp.getMixedStrides();
+
+  SmallVector<Value> newOperands;
+  for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
+    auto indexingMap = linalgOp.getTiedIndexingMap(operand);
+    auto newOffsets = applyProjectedPermutationMap(indexingMap, sliceOffsets);
+    auto newSizes = applyProjectedPermutationMap(indexingMap, sliceSizes);
+    auto newStrides = applyProjectedPermutationMap(indexingMap, sliceStrides);
+    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
+        sliceOp.getLoc(), operand->get(), newOffsets, newSizes, newStrides);
+    newOperands.push_back(newSlice);
+  }
+
+  Operation *newOp = linalgOp.clone(rewriter, linalgOp.getLoc(),
+                                    sliceOp.getType(), newOperands);
+  rewriter.replaceOp(sliceOp, newOp->getResults());
+  return success();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   CodegenStrategy.cpp
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5:index) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map0],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @dynamic
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+// CHECK: return %[[GENERIC]] : tensor<?x?xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map0],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<16x8xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<16x8xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+        : tensor<16x8xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+// CHECK: func @static
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+func @mixed(%arg0: tensor<?x?xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map0],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
+        : tensor<?x?xf32> to tensor<?x2xf32>
+  return %1 : tensor<?x2xf32>
+}
+
+// CHECK: func @mixed
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x?xf32> to tensor<?x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x?xf32> to tensor<?x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<?x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map0],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+        : tensor<?x?xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+// CHECK: func @dynamic_to_static
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<4x2xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,11 @@
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testBubbleUpExtractSliceOpPattern{
+      *this, "test-bubble-up-extract-slice-op-pattern",
+      llvm::cl::desc("Test rewrite of extract_slice(generic) into "
+                     "generic(extract_slice)"),
+      llvm::cl::init(false)};
 };
 
 } // namespace
@@ -616,6 +621,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<BubbleUpExtractSliceOpPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -665,6 +676,8 @@
   if (testTileScalarizeDynamicDims)
     return applyTilePattern(getOperation(), loopType, tileSizes,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
+  if (testBubbleUpExtractSliceOpPattern)
+    return applyBubbleUpExtractSliceOpPattern(getOperation());
 }
 
 namespace mlir {