diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -87,6 +87,19 @@
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.rewrite_as_constant",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor ops (such as tensor.generate) should be replaced with
+    constants (arith.constant) when possible. Currently this covers
+    tensor.generate ops with a static result shape whose body yields a
+    constant; they are replaced with a splat constant of the result type.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;
 
 def MakeLoopIndependentOp
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,10 @@
 /// respectively.
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that replace tensor ops (such as
+/// tensor.generate) with constants when possible.
+void populateRewriteAsConstantPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Transform helpers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -113,6 +113,11 @@
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateRewriteAsConstantPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // MakeLoopIndependentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   IndependenceTransforms.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
+  RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -0,0 +1,79 @@
+//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Returns true if `attr` can be splatted into a DenseElementsAttr whose
+/// element type is `elementType`. DenseElementsAttr::get asserts on any other
+/// combination (e.g. a vector constant yielded into a tensor of vectors), so
+/// such constants must be rejected before the attribute is built.
+static bool isSplattableElement(Attribute attr, Type elementType) {
+  // Integer, index and float elements need a scalar of exactly that type.
+  if (elementType.isIntOrIndexOrFloat()) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return intAttr.getType() == elementType;
+    if (auto floatAttr = dyn_cast<FloatAttr>(attr))
+      return floatAttr.getType() == elementType;
+    return false;
+  }
+  // Complex elements are encoded as a two-element [real, imag] ArrayAttr.
+  if (isa<ComplexType>(elementType)) {
+    auto parts = dyn_cast<ArrayAttr>(attr);
+    return parts && parts.size() == 2;
+  }
+  // DenseElementsAttr::get treats every other element type as string data.
+  return isa<StringAttr>(attr);
+}
+
+namespace {
+
+/// Rewrite tensor.generate with arith.constant if the tensor type is static
+/// and the yielded value is a constant that can be splatted into a
+/// DenseElementsAttr.
+struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    auto tensorType =
+        llvm::cast<RankedTensorType>(generateOp.getResult().getType());
+    if (!tensorType.hasStaticShape())
+      return failure();
+    auto terminatorOp =
+        cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
+    Attribute attr;
+    if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
+      return failure();
+    // Bail out rather than trip an assertion in DenseElementsAttr::get.
+    if (!isSplattableElement(attr, tensorType.getElementType()))
+      return failure();
+    Operation *constantOp =
+        rewriter.getContext()
+            ->getLoadedDialect<TensorDialect>()
+            ->materializeConstant(rewriter,
+                                  DenseElementsAttr::get(tensorType, attr),
+                                  tensorType, generateOp->getLoc());
+    if (!constantOp)
+      return failure();
+    rewriter.replaceOp(generateOp, constantOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateRewriteAsConstantPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<GenerateToConstant>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.apply_patterns to %module_op {
+    transform.apply_patterns.tensor.rewrite_as_constant
+  } : !transform.any_op
+}
+
+// CHECK-LABEL: func @tensor_generate_constant(
+//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
+//       CHECK:   return %[[cst]]
+func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
+  %cst = arith.constant 5.0 : f32
+  %0 = tensor.generate {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index):
+    tensor.yield %cst : f32
+  } : tensor<2x3x5xf32>
+  return %0 : tensor<2x3x5xf32>
+}
+
+// A vector constant cannot be splatted into a DenseElementsAttr, so the
+// tensor.generate must be left in place.
+// CHECK-LABEL: func @tensor_generate_vector_constant(
+//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<4xf32>
+//       CHECK:   %[[gen:.*]] = tensor.generate
+//       CHECK:     tensor.yield %[[cst]] : vector<4xf32>
+//       CHECK:   return %[[gen]]
+func.func @tensor_generate_vector_constant() -> tensor<2x3xvector<4xf32>> {
+  %cst = arith.constant dense<5.0> : vector<4xf32>
+  %0 = tensor.generate {
+    ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : vector<4xf32>
+  } : tensor<2x3xvector<4xf32>>
+  return %0 : tensor<2x3xvector<4xf32>>
+}