diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -55,6 +55,11 @@
 void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
                                        bool removeDeadArgsAndResults = true);
 
+/// Populate patterns that convert ops that are not in destination-passing
+/// style (e.g., tensor.generate) to destination-passing style ops. Newly
+/// created tensor.empty ops are used as the destinations.
+void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConstantFold.cpp
+  ConvertToDestinationStyle.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -0,0 +1,84 @@
+//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to convert non-DPS (destination-passing style)
+// ops to DPS ops. New tensor.empty ops are inserted as a destination. Such
+// tensor.empty can be eliminated with "empty tensor elimination", allowing
+// them to bufferize without an allocation (assuming there are no further
+// conflicts).
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Lower tensor.generate to linalg.generic. A new tensor.empty (with the same
+/// dynamic extents) becomes the `outs` operand of the linalg.generic, and the
+/// body of tensor.generate is moved into the linalg.generic, with its block
+/// arguments replaced by one linalg.index op per dimension.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    // Only ops with exactly one block are supported.
+    if (!generateOp.getBody().hasOneBlock())
+      return failure();
+
+    Location loc = generateOp.getLoc();
+    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+    // Create tensor.empty.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
+                                            generateOp.getDynamicExtents());
+
+    // Create linalg.generic.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        tensorType.getRank(), utils::IteratorType::parallel);
+    SmallVector<AffineMap> indexingMaps(
+        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, tensorType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+        indexingMaps, iteratorTypes);
+    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                       tensorType.getElementType(), loc);
+    rewriter.setInsertionPointToStart(body);
+    SmallVector<Value> bbArgReplacements;
+    for (int64_t i = 0; i < tensorType.getRank(); ++i)
+      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
+                         bbArgReplacements);
+
+    // Update terminator. Set the insertion point explicitly so that the new
+    // linalg.yield is created right before the tensor.yield it replaces.
+    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+    // Replace tensor.generate.
+    rewriter.replaceOp(generateOp, genericOp->getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateConvertToDestinationStylePatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<GenerateOpConverter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+
+// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_generate(
+//  CHECK-SAME:     %[[s1:.*]]: index, %[[s2:.*]]: index
+//       CHECK:   %[[empty:.*]] = tensor.empty(%[[s1]], %[[s2]]) : tensor<?x?xindex>
+//       CHECK:   %[[generic:.*]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]]], iterator_types = ["parallel", "parallel"]}
+//  CHECK-SAME:       outs(%[[empty]] : tensor<?x?xindex>) {
+//       CHECK:     %[[i0:.*]] = linalg.index 0
+//       CHECK:     %[[i1:.*]] = linalg.index 1
+//       CHECK:     %[[added:.*]] = arith.addi %[[i0]], %[[i1]]
+//       CHECK:     linalg.yield %[[added]]
+//       CHECK:   }
+//       CHECK:   return %[[generic]]
+func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
+  %0 = tensor.generate %s1, %s2 {
+    ^bb0(%arg0: index, %arg1: index):
+    %1 = arith.addi %arg0, %arg1 : index
+    tensor.yield %1 : index
+  } : tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,6 +128,10 @@
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testConvertToDestinationStylePatterns{
+      *this, "test-convert-to-destination-style-patterns",
+      llvm::cl::desc("Test patterns that convert ops to destination style"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +222,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  populateConvertToDestinationStylePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +254,8 @@
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testConvertToDestinationStylePatterns)
+    return applyConvertToDestinationStylePatterns(getOperation());
 }
 
 namespace mlir {