diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,8 @@
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
+
 /// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
 /// loops and memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -84,6 +84,14 @@
   ];
 }
 
+def LinalgInlineScalarOperands : FunctionPass<"linalg-inline-scalar-operands"> {
+  let summary = "Inline scalar operands into linalg generic ops";
+  let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect"
+  ];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -78,6 +78,9 @@
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to inline constant operands into linalg generic ops.
+void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -135,10 +135,17 @@
   /// Returns true if this affine map is a single result constant function.
   bool isSingleConstant() const;
 
+  /// Returns true if this affine map has only constant results.
+  bool isConstant() const;
+
   /// Returns the constant result of this map. This method asserts that the map
   /// has a single constant result.
   int64_t getSingleConstantResult() const;
 
+  /// Returns the constant results of this map. This method asserts that the map
+  /// has only constant results.
+  SmallVector<int64_t> getConstantResults() const;
+
   // Prints affine map to 'os'.
   void print(raw_ostream &os) const;
   void dump() const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -0,0 +1,116 @@
+//===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns/pass to inline scalar operands into a generic
+// operation. A scalar operand is an operand whose indexing map has only
+// constant results.
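+//
+// For example (an illustrative sketch): in
+//
+//   %res = linalg.generic
+//       {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
+//                         affine_map<(d0) -> (d0)>], ...}
+//       ins(%t, %scalar : tensor<4xf32>, tensor<f32>) outs(...) {...}
+//
+// the operand %scalar is accessed through the constant map (d0) -> (), so the
+// pattern below removes it from the operand list and instead materializes its
+// value inside the body with a `tensor.extract %scalar[]`. The names %t,
+// %scalar, and %res are placeholders for this sketch.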
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-inline-scalar-operands"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Collect the positions of all inputs whose indexing map has only
+    // constant results; the remaining inputs are kept as operands.
+    SmallVector<size_t> scalarOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<Value> newOperands;
+    for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(),
+                                             genericOp.getInputTensors()))) {
+      AffineMap map = std::get<0>(it.value());
+      if (map.isConstant()) {
+        scalarOperands.emplace_back(it.index());
+      } else {
+        newIndexingMaps.emplace_back(map);
+        newOperands.emplace_back(std::get<1>(it.value()));
+      }
+    }
+
+    if (scalarOperands.empty())
+      return failure();
+
+    newIndexingMaps.append(genericOp.getOutputIndexingMaps());
+
+    Location loc = genericOp->getLoc();
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp->getResultTypes(), newOperands,
+        genericOp.getOutputTensors(), newIndexingMaps,
+        llvm::to_vector<4>(
+            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
+                               newOp.region().begin());
+
+    Block *body = newOp.getBody();
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(body);
+
+    // Iterate in reverse so that erasing a block argument does not shift the
+    // positions of the arguments still to be processed.
+    for (auto idx : llvm::reverse(scalarOperands)) {
+      Value operand = genericOp.getInput(idx);
+      AffineMap map = genericOp.getInputIndexingMap(idx);
+      SmallVector<int64_t> indices = map.getConstantResults();
+      SmallVector<Value> indicesValues;
+      for (auto idx : indices)
+        indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, idx));
+      operand = rewriter.create<tensor::ExtractOp>(loc, operand, indicesValues);
+      body->getArgument(idx).replaceAllUsesWith(operand);
+      body->eraseArgument(idx);
+    }
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+/// Patterns that are used to inline constant operands into linalg generic
+/// ops.
+void mlir::linalg::populateInlineConstantOperandsPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<InlineScalarOperands>(context);
+}
+
+namespace {
+/// Pass that inlines scalar operands into linalg generic ops.
+struct LinalgInlineScalarOperandsPass
+    : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
+
+    populateInlineConstantOperandsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgInlineScalarOperandsPass() {
+  return std::make_unique<LinalgInlineScalarOperandsPass>();
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -287,11 +287,25 @@
   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
 }
 
+bool AffineMap::isConstant() const {
+  return llvm::all_of(getResults(), [](AffineExpr expr) {
+    return expr.isa<AffineConstantExpr>();
+  });
+}
+
 int64_t AffineMap::getSingleConstantResult() const {
   assert(isSingleConstant() && "map must have a single constant result");
   return getResult(0).cast<AffineConstantExpr>().getValue();
 }
 
+SmallVector<int64_t> AffineMap::getConstantResults() const {
+  assert(isConstant() && "map must have only constant results");
+  SmallVector<int64_t> result;
+  for (auto expr : getResults())
+    result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
+  return result;
+}
+
 unsigned AffineMap::getNumDims() const {
   assert(map && "uninitialized map storage");
   return map->numDims;
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-inline-scalar-operands -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+// CHECK-LABEL: mean_dynamic
+module attributes {tf.versions = {producer = 0 : i32}} {
+  func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
+    %c1 = constant 1 : index
+    %cst = constant 0.000000e+00 : f32
+    %0 = memref.dim %arg0, %c1 : tensor<4x?xf16>
+    %1 = linalg.init_tensor [4] : tensor<4xf32>
+    %2 = linalg.fill(%1, %cst) : tensor<4xf32>, f32 -> tensor<4xf32>
+    // CHECK: linalg.generic
+    %3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x?xf16>) outs(%2 : tensor<4xf32>) {
+    ^bb0(%arg1: f16, %arg2: f32): // no predecessors
+      %10 = fpext %arg1 : f16 to f32
+      %11 = addf %10, %arg2 : f32
+      linalg.yield %11 : f32
+    } -> tensor<4xf32>
+    %4 = index_cast %0 : index to i64
+    %5 = tensor.from_elements %4 : tensor<1xi64>
+    %6 = linalg.tensor_reshape %5 [] : tensor<1xi64> into tensor<i64>
+    %7 = linalg.init_tensor [4] : tensor<4xf16>
+    // CHECK: linalg.generic {indexing_maps = [#map{{.?}}, #map{{.?}}], iterator_types = ["parallel"]} ins(%{{.?}} : tensor<4xf32>) outs
+    %8 = linalg.generic {indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel"]} ins(%3, %6 : tensor<4xf32>, tensor<i64>) outs(%7 : tensor<4xf16>) {
+    ^bb0(%arg1: f32, %arg2: i64, %arg3: f16): // no predecessors
+      %10 = sitofp %arg2 : i64 to f32
+      %11 = divf %arg1, %10 : f32
+      %12 = fptrunc %11 : f32 to f16
+      linalg.yield %12 : f16
+    } -> tensor<4xf16>
+    %9 = linalg.tensor_reshape %8 [[0, 1]] : tensor<4xf16> into tensor<4x1xf16>
+    return %9 : tensor<4x1xf16>
+  }
+}
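+
+// -----
+
+// An additional minimal case (an illustrative sketch; the function and value
+// names here are hypothetical): %scalar is a rank-0 tensor read through the
+// constant-result map #map5, so it should be inlined into the body as a
+// tensor.extract.
+#map4 = affine_map<(d0) -> (d0)>
+#map5 = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @inline_zerod
+func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor<f32>) -> tensor<4xf32> {
+  %0 = linalg.init_tensor [4] : tensor<4xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins(%{{.*}} : tensor<4xf32>)
+  %1 = linalg.generic {indexing_maps = [#map4, #map5, #map4], iterator_types = ["parallel"]} ins(%arg0, %scalar : tensor<4xf32>, tensor<f32>) outs(%0 : tensor<4xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+    // CHECK: tensor.extract
+    %2 = divf %arg1, %arg2 : f32
+    linalg.yield %2 : f32
+  } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}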