diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -49,6 +49,10 @@
 /// respectively.
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that convert non-destination-style ops
+/// to destination-style ops.
+void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ConvertToDestinationStyle.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp
@@ -0,0 +1,78 @@
+//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to convert non-DPS ops to DPS ops. New
+// tensor.empty ops are inserted as a destination. Such tensor.empty ops can be
+// eliminated with "empty tensor elimination", allowing them to bufferize
+// without an allocation (assuming there are no further conflicts).
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Lowers tensor.generate to linalg.generic with a tensor.empty destination.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    // Only ops with exactly one block are supported.
+    if (!generateOp.getBody().hasOneBlock())
+      return failure();
+
+    Location loc = generateOp.getLoc();
+    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+    // Create a tensor.empty destination with the same dynamic extents.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
+                                            generateOp.getDynamicExtents());
+
+    // Create an all-parallel linalg.generic with an identity output map.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        tensorType.getRank(), utils::IteratorType::parallel);
+    SmallVector<AffineMap> indexingMaps(
+        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, tensorType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+        indexingMaps, iteratorTypes);
+    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                       tensorType.getElementType(), loc);
+    rewriter.setInsertionPointToStart(body);
+    SmallVector<Value> bbArgReplacements;
+    for (int64_t i = 0; i < tensorType.getRank(); ++i)
+      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
+                         bbArgReplacements);
+
+    // Replace the tensor.yield terminator with linalg.yield.
+    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+    // Replace tensor.generate with the result of the linalg.generic.
+    rewriter.replaceOp(generateOp, genericOp->getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+void tensor::populateConvertToDestinationStylePatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<GenerateOpConverter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir b/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+
+// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_generate(
+//  CHECK-SAME:     %[[s1:.*]]: index, %[[s2:.*]]: index
+//       CHECK:   %[[empty:.*]] = tensor.empty(%[[s1]], %[[s2]]) : tensor<?x?xindex>
+//       CHECK:   %[[generic:.*]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]]], iterator_types = ["parallel", "parallel"]}
+//  CHECK-SAME:       outs(%[[empty]] : tensor<?x?xindex>) {
+//       CHECK:     %[[i0:.*]] = linalg.index 0
+//       CHECK:     %[[i1:.*]] = linalg.index 1
+//       CHECK:     %[[added:.*]] = arith.addi %[[i0]], %[[i1]]
+//       CHECK:     linalg.yield %[[added]]
+//       CHECK:   }
+//       CHECK:   return %[[generic]]
+func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
+  %0 = tensor.generate %s1, %s2 {
+    ^bb0(%arg0: index, %arg1: index):
+    %1 = arith.addi %arg0, %arg1 : index
+    tensor.yield %1 : index
+  } : tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -43,6 +43,11 @@
 
   void runOnOperation() override;
 
+  Option<bool> testConvertToDestinationStylePatterns{
+      *this, "test-convert-to-destination-style-patterns",
+      llvm::cl::desc("Test patterns that convert ops to destination style"),
+      llvm::cl::init(false)};
+
   Option<bool> testSplitPaddingPatterns{
       *this, "test-split-padding-patterns",
       llvm::cl::desc("Test patterns to split tensor.pad ops"),
@@ -93,6 +98,12 @@
 };
 } // namespace
 
+static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateConvertToDestinationStylePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
@@ -288,6 +299,8 @@
 
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
+  if (testConvertToDestinationStylePatterns)
+    applyConvertToDestinationStylePatterns(rootOp);
   if (testSimplifyPackPatterns)
     applySimplifyPackPatterns(rootOp);
   if (testSplitPaddingPatterns)