diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -119,6 +119,9 @@
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to bubble up extract slice op above linalg op.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -166,16 +166,21 @@
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
-/// at offsets `lbs` and with sizes `subShapeSizes`.
+/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
+/// controls whether to omit the partial/boundary tile condition check in cases
+/// where we statically know that it is unnecessary.
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes);
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck);
 
 /// Creates extract_slice/subview ops for all `valuesToTile` of the given
 /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
 /// nest for tiling with the given induction variables `ivs` and tile sizes
 /// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
-/// implicit loops in `linalgOp`.
+/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
+/// omit the partial/boundary tile condition check in cases where we statically
+/// know that it is unnecessary.
 ///
 /// Note that a constant zero in `tileSizes` means no tiling at that implicit
 /// loop. The number of non-zero values in `tileSizes` should be equal to the
@@ -184,7 +189,8 @@
                                     LinalgOp linalgOp,
                                     ArrayRef<Value> valuesToTile,
                                     ValueRange ivs, ValueRange tileSizes,
-                                    ArrayRef<Value> sizeBounds);
+                                    ArrayRef<Value> sizeBounds,
+                                    bool omitPartialTileCheck);
 
 /// Add the tile loop induction variables `ivs` to the IndexOp results found in
 /// the body of the `tiledOp` to account for the tile offset.
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -480,7 +480,7 @@
 /// ```mlir
 /// affine_map<(d0, d1) -> (d0, 0, 0)>
 /// ```
-AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
+AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
 
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
 /// potentially empty maps. Assumes each of the underlying map has 0 symbols.
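For downstream users, the new entry point is driven like any other Linalg populate function. Below is a minimal sketch of a client pass; the pass itself (`BubbleUpPass`) is hypothetical and not part of this patch, only the populate call and the greedy driver mirror what the test pass added at the end of this patch does.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
// Hypothetical standalone pass that applies the bubble-up patterns to a
// function until a fixed point is reached.
struct BubbleUpPass : public PassWrapper<BubbleUpPass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    linalg::populateBubbleUpExtractSliceOpPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```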
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -0,0 +1,139 @@
+//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns that transform linalg.<op> +
+// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
+// the computation for the linalg op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Bubble up extract_slice above Linalg operation.
+///
+/// A sequence of operations
+///
+/// ```mlir
+/// %0 = linalg.<op> ... arg0, arg1, ...
+/// %1 = tensor.extract_slice %0 ...
+/// ```
+///
+/// can be replaced with
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+/// %1 = tensor.extract_slice %arg1
+/// %2 = linalg.<op> ... %0, %1, ...
+/// ```
+///
+/// This reduces the amount of computation performed by the linalg op since it
+/// now operates only on slices of its operands.
+struct BubbleUpExtractSliceOpPattern
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value source = sliceOp.source();
+    auto linalgOp = source.getDefiningOp<LinalgOp>();
+    if (!linalgOp) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected source to be linalg op");
+    }
+
+    // TODO: we might relax this if we want heuristics to detect that all uses
+    // are a small portion of the output.
+    if (!linalgOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single use of linalg op");
+    }
+
+    if (linalgOp.getNumOutputs() != 1) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single output of linalg op");
+    }
+
+    if (!linalgOp.hasTensorSemantics()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected tensor of linalg op");
+    }
+
+    if (!sliceOp.hasUnitStride())
+      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+    OpOperand *outOperand = linalgOp.getOutputOperand(0);
+    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+    if (!indexingMap.isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expected a projected permutation for output");
+    }
+
+    auto linalgLoc = linalgOp.getLoc();
+    auto allShapeSizes =
+        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
+    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
+    if (!shapeSizesToLoopsMap) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "failed to get loops map from shape sizes");
+    }
+    auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
+                                       shapeSizesToLoopsMap, allShapeSizes);
+
+    auto sliceLoc = sliceOp.getLoc();
+    auto offsetVals = getValueOrCreateConstantIndexOp(
+        rewriter, sliceLoc, sliceOp.getMixedOffsets());
+    auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
+                                                    sliceOp.getMixedSizes());
+
+    // The offsets and sizes from the slice operation only give the tile of the
+    // output. Use them to compute the tile offsets and sizes of the loops. For
+    // loops not used to access the output, set the tile sizes to the loop
+    // bounds and set the offsets to 0.
+    Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
+    SmallVector<Value> tileOffsets(sizeBounds.size(), zero);
+    SmallVector<Value> tileSizes = sizeBounds;
+    for (auto const &result : enumerate(indexingMap.getResults())) {
+      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
+      tileOffsets[position] = offsetVals[result.index()];
+      tileSizes[position] = sizeVals[result.index()];
+    }
+
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+
+    SmallVector<Value> tiledOperands = makeTiledShapes(
+        rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
+        sizeBounds, /*omitPartialTileCheck=*/true);
+
+    SmallVector<Type> resultTensorTypes;
+    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
+      resultTensorTypes.push_back(
+          tiledOperands[opOperand->getOperandNumber()].getType());
+
+    Operation *newOp =
+        linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
+    rewriter.replaceOp(sliceOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<BubbleUpExtractSliceOpPattern>(context);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   CodegenStrategy.cpp
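Concretely, on a matmul followed by a slice of its result, the rewrite produces IR along the following lines (a hand-written sketch matching the `@matmul_slice` test added below; SSA names are illustrative). Note that the reduction dimension keeps its full extent because it does not appear in the output indexing map:

```mlir
// Before: the whole 4x4 result is computed, then a 2x2 tile is kept.
%0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>)
                   outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
%1 = tensor.extract_slice %0[1, 1] [2, 2] [1, 1]
       : tensor<4x4xf32> to tensor<2x2xf32>

// After: only the operand slices feeding the 2x2 tile are read and computed.
%lhs_s = tensor.extract_slice %lhs[1, 0] [2, 4] [1, 1]
           : tensor<4x4xf32> to tensor<2x4xf32>
%rhs_s = tensor.extract_slice %rhs[0, 1] [4, 2] [1, 1]
           : tensor<4x4xf32> to tensor<4x2xf32>
%dst_s = tensor.extract_slice %dst[1, 1] [2, 2] [1, 1]
           : tensor<4x4xf32> to tensor<2x2xf32>
%2 = linalg.matmul ins(%lhs_s, %rhs_s : tensor<2x4xf32>, tensor<4x2xf32>)
                   outs(%dst_s : tensor<2x2xf32>) -> tensor<2x2xf32>
```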
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -142,9 +142,9 @@
   clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
-  clonedShapes.append(makeTiledShapes(b, loc, producer,
-                                      getTiledOperands(producer), ivs,
-                                      tileSizes, sizeBounds));
+  clonedShapes.append(makeTiledShapes(
+      b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
+      /*omitPartialTileCheck=*/false));
 
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -163,7 +163,8 @@
   erase_value(tileIvs, nullptr);
   SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
   tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
-                                  tileSizes, producerLoopBounds);
+                                  tileSizes, producerLoopBounds,
+                                  /*omitPartialTileCheck=*/false);
 
   // Output fusion has to update the iteration arguments of the tile loop nest.
   // In particular, the iteration argument of the outermost tile loop needs to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -178,8 +178,9 @@
   SmallVector<Value> valuesToTile = operandValuesToUse;
   auto sizeBounds =
       applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
-  SmallVector<Value> tiledOperands = makeTiledShapes(
-      b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);
+  SmallVector<Value> tiledOperands =
+      makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
+                      sizeBounds, /*omitPartialTileCheck=*/false);
 
   // TODO: use an interface/adaptor to avoid leaking position in
   // `tiledOperands`.
@@ -325,9 +326,9 @@
     // Note: The tensor::PadOp is located outside of the loop nest. It is
     // later moved inside by ExtractSliceOfPadTensorSwapPattern.
     auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
-    Value tiledOutput =
-        makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
-                       offsets, allDims, sizes);
+    Value tiledOutput = makeTiledShape(
+        b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
+        sizes, /*omitPartialTileCheck=*/false);
     auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
     assert(sliceOp && "expected ExtractSliceOp");
     // Insert the tile into the output tensor.
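All existing callers in Fusion.cpp, FusionOnTensors.cpp, and Tiling.cpp pass `omitPartialTileCheck=false`, so their behavior is unchanged: when a tile may run past the iteration-space boundary, `makeTiledShape` clamps the slice size with an `affine.min`. The bubble-up pattern can pass `true` because its offsets and sizes are taken from an `extract_slice` that is in bounds by construction. The shape of the emitted IR is roughly as below (an illustrative sketch, not verbatim output of the patch):

```mlir
// omitPartialTileCheck == false: the size along a tiled dimension is clamped
// to the remaining extent, i.e. min(%tile_size, %dim_ub - %iv).
%size = affine.min affine_map<(d0)[s0, s1] -> (s0, s1 - d0)>(%iv)[%tile_size, %dim_ub]
%partial = tensor.extract_slice %t[%iv] [%size] [1] : tensor<?xf32> to tensor<?xf32>

// omitPartialTileCheck == true: the raw tile size is used directly.
%full = tensor.extract_slice %t[%iv] [%tile_size] [1] : tensor<?xf32> to tensor<?xf32>
```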
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -504,7 +504,7 @@
     //   readType = VectorType::get({}, bbarg.getType());
     // } else {
     if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
-      map = inverseAndBroadcastProjectedPermuation(
+      map = inverseAndBroadcastProjectedPermutation(
           linalgOp.getTiedIndexingMap(opOperand));
       readType = VectorType::get(commonVectorShape,
                                  getElementTypeOrSelf(opOperand->get()));
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -745,7 +745,8 @@
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes) {
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck) {
   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef<int64_t> shape = shapedType.getShape();
@@ -773,7 +774,7 @@
     auto m = map.getSubMap({r});
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
     auto offset = applyMapToValues(builder, loc, m, lbs).front();
-    offsets.push_back(offset);
+    offsets.push_back(getAsOpFoldResult(offset));
     auto closedIntSize =
         applyMapToValues(builder, loc, m, subShapeSizes).front();
     // Resulting size needs to be made half open interval again.
@@ -781,6 +782,17 @@
     Value size =
         fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "makeTiledShape: new offset: " << offset << "\n");
+    strides.push_back(builder.getIndexAttr(1));
+
+    if (omitPartialTileCheck) {
+      // We statically know that the partial/boundary tile condition is
+      // unnecessary.
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
+      sizes.push_back(getAsOpFoldResult(size));
+      continue;
+    }
 
     // The size of the subview / extract_slice should be trimmed to avoid
     // out-of-bounds accesses, unless:
@@ -829,12 +841,8 @@
       size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
                                          operands);
     }
-
-    sizes.push_back(size);
-    LLVM_DEBUG(llvm::dbgs()
-               << "makeTiledShape: new offset: " << offset << "\n");
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
-    strides.push_back(builder.getIndexAttr(1));
+    sizes.push_back(getAsOpFoldResult(size));
   }
 
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
@@ -886,7 +894,8 @@
                                     LinalgOp linalgOp,
                                     ArrayRef<Value> valuesToTile,
                                     ValueRange ivs, ValueRange tileSizes,
-                                    ArrayRef<Value> sizeBounds) {
+                                    ArrayRef<Value> sizeBounds,
+                                    bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -921,7 +930,8 @@
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
     tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
-                                         sizeBounds, subShapeSizes));
+                                         sizeBounds, subShapeSizes,
+                                         omitPartialTileCheck));
   }
 
   return tiledShapes;
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -679,7 +679,7 @@
   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
-AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
+AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
   MLIRContext *context = map.getContext();
   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
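A worked example of the renamed helper, consistent with the header comment shown earlier: given a projected permutation over three loop dimensions that uses only `d0` (zeros in results are allowed by the assert), the function inverts the used dimension and broadcasts the unused ones to zero, which yields the header's example result:

```mlir
// input: uses only d0, plus a permitted constant-0 result
affine_map<(d0, d1, d2) -> (d0, 0)>
// inverseAndBroadcastProjectedPermutation(input):
affine_map<(d0, d1) -> (d0, 0, 0)>
```

In Vectorization.cpp above, the resulting map is used when building the vector read of an input operand, so loop dimensions that the operand does not index are broadcast.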
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+
+  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+    : tensor<?x?xf32> to tensor<?x?xf32>
+
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @dynamic
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+// CHECK: return %[[GENERIC]] : tensor<?x?xf32>
+
+// -----
+
+func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<16x8xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<16x8xf32>
+
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<16x8xf32> to tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+}
+
+// CHECK: func @static
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+// -----
+
+func @mixed(%arg0: tensor<?x8xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<?x8xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x8xf32>
+
+  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
+    : tensor<?x8xf32> to tensor<?x2xf32>
+
+  return %1 : tensor<?x2xf32>
+}
+
+// CHECK: func @mixed
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<?x2xf32>
+
+// -----
+
+func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %add = arith.addf %b0, %b1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<?x?xf32> to tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+}
+
+// CHECK: func @dynamic_to_static
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+// CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+// -----
+
+func @matmul_slice() -> tensor<2x2xf32> {
+  %lhs = arith.constant dense<1.0> : tensor<4x4xf32>
+  %rhs = arith.constant dense<1.0> : tensor<4x4xf32>
+  %dst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]> : tensor<4x4xf32>
+
+  %0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>) outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  %1 = tensor.extract_slice %0[1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+
+  return %1 : tensor<2x2xf32>
+}
+
+// CHECK: func @matmul_slice
+// CHECK: %[[SLICE0:.+]] = arith.constant dense<1.000000e+00> : tensor<2x4xf32>
+// CHECK: %[[SLICE1:.+]] = arith.constant dense<1.000000e+00> : tensor<4x2xf32>
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[CST:.+]][1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[SLICE0]], %[[SLICE1]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[SLICE3]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK: return %[[MATMUL]] : tensor<2x2xf32>
+
+// -----
+
+func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>) -> tensor<1x32x32x16xf32> {
+  %c112 = arith.constant 112 : index
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+  %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %conv = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
+    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %slice = tensor.extract_slice %conv [0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+
+  return %slice : tensor<1x32x32x16xf32>
+}
+
+// CHECK: func @conv_slice
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -132,6 +132,11 @@
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testBubbleUpExtractSliceOpPattern{
+      *this, "test-bubble-up-extract-slice-op-pattern",
+      llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
+                     "extract_slice + linalgOp"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -635,6 +640,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateBubbleUpExtractSliceOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -686,6 +697,8 @@
                              /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testSplitReduction)
     return applySplitReduction(getOperation());
+  if (testBubbleUpExtractSliceOpPattern)
+    return applyBubbleUpExtractSliceOpPattern(getOperation());
 }
 
 namespace mlir {
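For reference, the new test can be exercised directly, mirroring its RUN line (path relative to the source tree, piped through `FileCheck` in the actual test):

```
mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern \
    -split-input-file mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
```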