diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,8 @@
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
+
 /// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
 /// loops and memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -84,6 +84,14 @@
   ];
 }
 
+def LinalgInlineScalarOperands : FunctionPass<"linalg-inline-scalar-operands"> {
+  let summary = "Inline scalar operands into linalg generic ops";
+  let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect"
+  ];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -78,6 +78,9 @@
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to inline constant operands into linalg generic ops.
+void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -135,10 +135,17 @@
   /// Returns true if this affine map is a single result constant function.
   bool isSingleConstant() const;
 
+  /// Returns true if this affine map has only constant results.
+  bool isConstant() const;
+
   /// Returns the constant result of this map. This method asserts that the map
   /// has a single constant result.
   int64_t getSingleConstantResult() const;
 
+  /// Returns the constant results of this map. This method asserts that the
+  /// map has all constant results.
+  SmallVector<int64_t> getConstantResults() const;
+
   // Prints affine map to 'os'.
   void print(raw_ostream &os) const;
   void dump() const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -0,0 +1,110 @@
+//===- InlineScalarOperands.cpp - Pass to inline scalar operands ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns/pass to inline scalar operands into a generic
+// operation. A scalar operand is an operand whose indexing map has only
+// constant results.
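+//
+// For example (mirroring the zero-d test added below), an input whose
+// indexing map is affine_map<(d0) -> ()> is removed from the ins() list,
+// and uses of its block argument are replaced by
+//   %s = tensor.extract %scalar[] : tensor<f32>
+// materialized at the start of the body.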
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    SmallVector<size_t> scalarOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<Value> newOperands;
+    for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(),
+                                             genericOp.getInputTensors()))) {
+      AffineMap map = std::get<0>(it.value());
+      if (map.isConstant()) {
+        scalarOperands.emplace_back(it.index());
+      } else {
+        newIndexingMaps.emplace_back(map);
+        newOperands.emplace_back(std::get<1>(it.value()));
+      }
+    }
+
+    if (scalarOperands.empty())
+      return failure();
+
+    newIndexingMaps.append(genericOp.getOutputIndexingMaps());
+
+    Location loc = genericOp->getLoc();
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp->getResultTypes(), newOperands,
+        genericOp.getOutputTensors(), newIndexingMaps,
+        llvm::to_vector<4>(
+            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
+                               newOp.region().begin());
+
+    Block *body = newOp.getBody();
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(body);
+
+    for (auto idx : llvm::reverse(scalarOperands)) {
+      Value operand = genericOp.getInput(idx);
+      AffineMap map = genericOp.getInputIndexingMap(idx);
+      SmallVector<int64_t> indices = map.getConstantResults();
+      SmallVector<Value> indicesValues;
+      for (auto idx : indices)
+        indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, idx));
+      operand = rewriter.create<tensor::ExtractOp>(loc, operand, indicesValues);
+      body->getArgument(idx).replaceAllUsesWith(operand);
+      body->eraseArgument(idx);
+    }
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+/// Patterns that are used to inline constant operands into linalg generic
+/// ops.
+void mlir::linalg::populateInlineConstantOperandsPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<InlineScalarOperands>(context);
+}
+
+namespace {
+/// Pass that inlines scalar operands into linalg generic ops.
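+/// It applies the InlineScalarOperands pattern greedily to the enclosing
+/// function until no constant-indexed input operands remain.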
+struct LinalgInlineScalarOperandsPass
+    : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
+
+    populateInlineConstantOperandsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgInlineScalarOperandsPass() {
+  return std::make_unique<LinalgInlineScalarOperandsPass>();
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -287,11 +287,25 @@
   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
 }
 
+bool AffineMap::isConstant() const {
+  return llvm::all_of(getResults(), [](AffineExpr expr) {
+    return expr.isa<AffineConstantExpr>();
+  });
+}
+
 int64_t AffineMap::getSingleConstantResult() const {
   assert(isSingleConstant() && "map must have a single constant result");
   return getResult(0).cast<AffineConstantExpr>().getValue();
 }
 
+SmallVector<int64_t> AffineMap::getConstantResults() const {
+  assert(isConstant() && "map must have only constant results");
+  SmallVector<int64_t> result;
+  for (auto expr : getResults())
+    result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
+  return result;
+}
+
 unsigned AffineMap::getNumDims() const {
   assert(map && "uninitialized map storage");
   return map->numDims;
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-inline-scalar-operands -split-input-file | FileCheck %s
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_zerod(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<f32>)
+func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor<f32>) -> tensor<4xf32> {
+  %0 = linalg.init_tensor [4] : tensor<4xf32>
+  // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+  // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+  %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                       iterator_types = ["parallel"]}
+                       ins(%arg0, %scalar : tensor<4xf32>, tensor<f32>)
+                       outs(%0 : tensor<4xf32>) {
+  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+    // CHECK: tensor.extract %[[SCALAR]][]
+    %2 = divf %arg1, %arg2 : f32
+    linalg.yield %2 : f32
+  } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> (0)>
+
+// CHECK: func @inline_oned(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<1xf32>)
+func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4xf32> {
+  // CHECK: %[[ZERO:.*]] = constant 0 : index
+  %0 = linalg.init_tensor [4] : tensor<4xf32>
+  // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+  // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+  %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                       iterator_types = ["parallel"]}
+                       ins(%arg0, %scalar : tensor<4xf32>, tensor<1xf32>)
+                       outs(%0 : tensor<4xf32>) {
+  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+    // CHECK: tensor.extract %[[SCALAR]][%[[ZERO]]]
+    %2 = divf %arg1, %arg2 : f32
+    linalg.yield %2 : f32
+  } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
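
Usage sketch (illustrative, not part of the patch): the pass is exercised exactly as in the
test's RUN line, e.g. `mlir-opt input.mlir -linalg-inline-scalar-operands`. Downstream code
can also reuse the pattern directly; a minimal sketch mirroring the pass body above, assuming
a FuncOp-scoped pass named MyPass (hypothetical name) defined inside the mlir namespace:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  void MyPass::runOnFunction() {
    FuncOp funcOp = getFunction();
    RewritePatternSet patterns(funcOp.getContext());
    // Pull in the inline-scalar-operands pattern added by this patch.
    linalg::populateInlineConstantOperandsPatterns(patterns);
    // Rewrite greedily until no constant-indexed input operands remain.
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }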