diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -65,25 +65,31 @@ /// Function type which is used to control when to stop fusion. It is expected /// that OpOperand is not modified in the callback. The OpOperand is not marked /// as const to allow callers to use non-const methods. -using ControlElementwiseOpsFusionFn = +using ControlFusionFn = std::function; +/// Patterns for fusing linalg operation on tensors. + +/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations +/// when both operations are fusable elementwise operations. +void populateElementwiseOpsFusionPatterns( + RewritePatternSet &patterns, + const ControlFusionFn &controlElementwiseOpFusion); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. void populateFoldReshapeOpsByExpansionPatterns( - RewritePatternSet &patterns, - const ControlElementwiseOpsFusionFn &controlFoldingReshapes = - skipUnitDimReshape); + RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); /// Patterns to fold an expanding tensor.expand_shape operation with its /// producer generic operation by collapsing the dimensions of the generic op. void populateFoldReshapeOpsByCollapsingPatterns( - RewritePatternSet &patterns, - const ControlElementwiseOpsFusionFn &controlFoldingReshapes = - [](const OpResult & /*producer*/, OpOperand & /*consumer*/) { - return true; - }); + RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); + +/// Patterns to constant fold Linalg operations. 
+void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, + const ControlFusionFn &controlFn); /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its /// producer (consumer) generic operation by linearizing the indexing map used @@ -122,39 +128,6 @@ /// Patterns that are used to bubble up extract slice op above linalg op. void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); -/// Options that control fusion of elementwise operations. -struct LinalgElementwiseFusionOptions { - /// Enable fusion of reshapes into the shape with elementwise operations. By - /// default it is disabled for unit dimensions reshape. - ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape; - - LinalgElementwiseFusionOptions & - setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) { - controlFoldingReshapesFn = std::move(fun); - return *this; - } - - /// Function to allow the caller to control when to stop fusion. Once a - /// producer is deemed fusable with the consumer (structurally), this callback - /// can be used to abort the fusion based on non-structural constraints. This - /// is the hook for cost models to control the amount of fusion done. - ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn = - [](const OpResult & /*producer */, OpOperand & /*consumer */) { - return true; - }; - - LinalgElementwiseFusionOptions & - setControlElementwiseOpsFusionFn(ControlElementwiseOpsFusionFn fun) { - controlElementwiseOpsFusionFn = std::move(fun); - return *this; - } -}; - -/// Patterns for fusing linalg operation on tensors. -void populateElementwiseOpsFusionPatterns( - RewritePatternSet &patterns, - LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions()); - /// Patterns to push reshape op towards the end of the graph in order to expose /// more fusion opportunities. 
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ Bufferize.cpp CodegenStrategy.cpp ComprehensiveBufferizePass.cpp + ConstantFold.cpp Detensorize.cpp DropUnitDims.cpp ElementwiseOpFusion.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -0,0 +1,308 @@ +//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements constant folding on Linalg operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +/// Base class for constant folding linalg.generic ops with N inputs, 1 output, +/// and permutation indexing maps. 
+/// +/// `ConcreteType` should provide methods with signatures +/// +/// ```c++ +/// bool matchIndexingMaps(GenericOp genericOp) const; +/// RegionComputationFn getRegionComputeFn(GenericOp) const; +/// ``` +/// +/// The latter inspects the region and returns the computation inside as a +/// functor. The functor will be invoked with constant elements for all inputs +/// and should return the corresponding computed constant element for output. +template +class FoldConstantBase : public OpRewritePattern { +public: + struct APIntOrFloat { + Optional apInt; + Optional apFloat; + }; + struct APIntOrFloatArray { + SmallVector apInts; + SmallVector apFloats; + }; + using RegionComputationFn = + std::function; + + FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(controlFn) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (genericOp.hasBufferSemantics()) + return failure(); + + // Only support ops generating one output for now. + if (genericOp.getNumOutputs() != 1) + return failure(); + + auto outputType = genericOp.getResultTypes().front().dyn_cast(); + // Require the output types to be static given that we are generating + // constants. + if (!outputType || !outputType.hasStaticShape()) + return failure(); + + if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { + return operand->get().getType().isa(); + })) + return failure(); + + // Make sure all element types are the same. + auto getOperandElementType = [](OpOperand *operand) { + return operand->get().getType().cast().getElementType(); + }; + if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), + getOperandElementType))) + return failure(); + + // We can only handle the case where we have int/float elements. 
+ auto elementType = outputType.getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); + + // Require all indexing maps to be permutations for now. This is common and + // it simplifies input/output access greatly: we can do the data shuffling + // entirely in the compiler, without needing to turn all indices into + // Values, and then do affine apply on them, and then match back the + // constant again. + if (!llvm::all_of(genericOp.getIndexingMaps(), + [](AffineMap map) { return map.isPermutation(); })) + return failure(); + + for (OpOperand *operand : genericOp.getOutputOperands()) { + if (genericOp.payloadUsesValueFromOperand(operand)) + return failure(); + } + + // Further check the indexing maps are okay for the ConcreteType. + if (!static_cast(this)->matchIndexingMaps(genericOp)) + return failure(); + + // Defer to the concrete type to check the region and discover the + // computation inside. + RegionComputationFn computeFn = + static_cast(this)->getRegionComputeFn(genericOp); + if (!computeFn) + return failure(); + + // All inputs should be constants. + int numInputs = genericOp.getNumInputs(); + SmallVector inputValues(numInputs); + for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { + if (!matchPattern(operand.value()->get(), + m_Constant(&inputValues[operand.index()]))) + return failure(); + } + + // Identified this as a potential candidate for folding. Now check the + // policy to see whether we are allowed to proceed. + for (int i = 0; i < numInputs; ++i) { + OpOperand *consumer = genericOp.getInputOperand(i); + OpResult producer = consumer->get().cast(); + if (!controlFn(producer, *consumer)) + return failure(); + } + + auto linalgOp = cast(genericOp.getOperation()); + SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); + int64_t numElements = outputType.getNumElements(); + + // Use APInt/APFloat instead of Attribute here for constructing the output. 
+ // This helps to avoid blowing up compiler memory usage: Attributes would + // unify the following cases but they live as long as the MLIRContext. + SmallVector intOutputValues; + SmallVector fpOutputValues; + if (elementType.template isa()) + fpOutputValues.resize(numElements, APFloat(0.f)); + else + intOutputValues.resize(numElements); + + // Return the constant dim positions from the given permutation map. + auto getDimPositions = [](AffineMap map) { + SmallVector dims; + dims.reserve(map.getNumResults()); + for (AffineExpr result : map.getResults()) { + dims.push_back(result.cast().getPosition()); + } + return dims; + }; + + SmallVector> inputDims; + for (int i = 0; i < numInputs; ++i) + inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i])); + auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); + auto outputShape = outputType.getShape(); + + // Allocate small vectors for index delinearization. Initial values do not + // matter here as they will be overwritten later. + SmallVector indices(loopBounds.size(), 0); + SmallVector dstIndices(loopBounds.size(), 0); + SmallVector> srcIndices( + numInputs, SmallVector(loopBounds.size(), 0)); + SmallVector srcLinearIndices(numInputs, 0); + uint64_t dstLinearIndex = 0; + + // Allocate space for compute function inputs. Initial values do not matter + // here as they will be overwritten later. + APIntOrFloatArray computeFnInputs; + + auto inputShapes = llvm::to_vector<4>( + llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { + return operand->get().getType().cast().getShape(); + })); + + // Given a `linearIndex`, remap it to a linear index to access linalg op + // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`, + // `srcLinearIndices`, `dstLinearIndex` in place. 
+ auto computeRemappedLinearIndex = [&](int linearIndex) { + int totalCount = linearIndex; + for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { + indices[dim] = totalCount % loopBounds[dim]; + totalCount /= loopBounds[dim]; + } + + for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { + for (int i = 0; i < numInputs; ++i) + srcIndices[i][dim] = indices[inputDims[i][dim]]; + dstIndices[dim] = indices[outputDims[dim]]; + } + + dstLinearIndex = dstIndices.front(); + for (int i = 0; i < numInputs; ++i) + srcLinearIndices[i] = srcIndices[i].front(); + + for (int dim = 1; dim < outputType.getRank(); ++dim) { + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; + for (int i = 0; i < numInputs; ++i) + srcLinearIndices[i] = + srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; + } + }; + + bool isFloat = elementType.isa(); + if (isFloat) { + SmallVector> inFpRanges; + for (int i = 0; i < numInputs; ++i) + inFpRanges.push_back(inputValues[i].getValues()); + + computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); + + // Transpose the input constant. Because we don't know its rank in + // advance, we need to loop over the range [0, element count) and + // delinearize the index. + for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { + computeRemappedLinearIndex(linearIndex); + + // Collect constant elements for all inputs at this loop iteration. + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; + + // Invoke the computation to get the corresponding constant output + // element. + fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; + } + } else { + SmallVector> inIntRanges; + for (int i = 0; i < numInputs; ++i) + inIntRanges.push_back(inputValues[i].getValues()); + + computeFnInputs.apInts.resize(numInputs); + + // Transpose the input constant. 
Because we don't know its rank in + // advance, we need to loop over the range [0, element count) and + // delinearize the index. + for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { + computeRemappedLinearIndex(linearIndex); + + // Collect constant elements for all inputs at this loop iteration. + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; + + // Invoke the computation to get the corresponding constant output + // element. + intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; + } + } + + DenseElementsAttr outputAttr = + isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) + : DenseElementsAttr::get(outputType, intOutputValues); + + rewriter.replaceOpWithNewOp(genericOp, outputAttr); + return success(); + } + +private: + ControlFusionFn controlFn; +}; + +// Folds linalg.generic ops that are actually transposes on constant values. +struct FoldConstantTranspose : public FoldConstantBase { + using FoldConstantBase::FoldConstantBase; + + bool matchIndexingMaps(GenericOp genericOp) const { + // We should have one input and one output. + return genericOp.getIndexingMaps().size() == 2; + } + + RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { + // Make sure the region only contains a yield op. + Block &body = genericOp.region().front(); + if (!llvm::hasSingleElement(body)) + return nullptr; + auto yieldOp = dyn_cast(body.getTerminator()); + if (!yieldOp) + return nullptr; + + // The yield op should return the block argument for the input. + for (Value yieldVal : yieldOp.values()) { + auto yieldArg = yieldVal.dyn_cast(); + if (!yieldArg || yieldArg.getOwner() != &body) + return nullptr; + if (yieldArg.getArgNumber() != 0) + return nullptr; + } + + // No computation; just return the original value. 
+ return [](const APIntOrFloatArray &inputs) { + if (inputs.apFloats.empty()) + return APIntOrFloat{inputs.apInts.front(), llvm::None}; + return APIntOrFloat{llvm::None, inputs.apFloats.front()}; + }; + } + + ControlFusionFn controlFn; +}; +} // namespace + +void mlir::linalg::populateConstantFoldLinalgOperations( + RewritePatternSet &patterns, const ControlFusionFn &controlFn) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context, controlFn); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -248,7 +248,7 @@ static Optional> fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, - const ControlElementwiseOpsFusionFn &controlFn, + const ControlFusionFn &controlFn, PatternRewriter &rewriter) { auto consumer = cast(consumerOpOperand->getOwner()); if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || @@ -352,8 +352,7 @@ static Optional> fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, - GenericOp producer, - const ControlElementwiseOpsFusionFn &controlFn) { + GenericOp producer, const ControlFusionFn &controlFn) { if (producer->getNumResults() != 1) return llvm::None; @@ -365,9 +364,10 @@ /// Patterns to fuse a generic op, with the producer of its operands. 
class FuseElementwiseOps : public OpRewritePattern { public: - FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun, + FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(fun) {} + : OpRewritePattern(context, benefit), + controlFn(std::move(fun)) {} LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { @@ -388,7 +388,7 @@ } private: - ControlElementwiseOpsFusionFn controlFn; + ControlFusionFn controlFn; }; } // namespace @@ -1078,9 +1078,9 @@ class FoldWithProducerReshapeOpByExpansion : public OpRewritePattern { public: - FoldWithProducerReshapeOpByExpansion( - MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, - PatternBenefit benefit = 1) + FoldWithProducerReshapeOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFoldingReshapes(std::move(foldReshapes)) {} @@ -1109,7 +1109,7 @@ } private: - ControlElementwiseOpsFusionFn controlFoldingReshapes; + ControlFusionFn controlFoldingReshapes; }; /// Pattern to fold a tensor_expand_shape op with its producer generic op @@ -1117,9 +1117,9 @@ struct FoldReshapeWithGenericOpByExpansion : public OpRewritePattern { - FoldReshapeWithGenericOpByExpansion( - MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, - PatternBenefit benefit = 1) + FoldReshapeWithGenericOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFoldingReshapes(std::move(foldReshapes)) {} @@ -1142,7 +1142,7 @@ } private: - ControlElementwiseOpsFusionFn controlFoldingReshapes; + ControlFusionFn controlFoldingReshapes; }; } // namespace @@ -1562,9 +1562,9 @@ class FoldWithProducerReshapeOpByCollapsing : public OpRewritePattern { public: - FoldWithProducerReshapeOpByCollapsing( - MLIRContext *context, 
ControlElementwiseOpsFusionFn foldReshapes, - PatternBenefit benefit = 1) + FoldWithProducerReshapeOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFoldingReshapes(std::move(foldReshapes)) {} @@ -1596,7 +1596,7 @@ } private: - ControlElementwiseOpsFusionFn controlFoldingReshapes; + ControlFusionFn controlFoldingReshapes; }; } // namespace @@ -1777,10 +1777,8 @@ /// handle cases where the constant is not single-valued. class FoldScalarOrSplatConstant : public OpRewritePattern { public: - FoldScalarOrSplatConstant(MLIRContext *context, - ControlElementwiseOpsFusionFn &fun, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(fun) {} + FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { @@ -1817,8 +1815,7 @@ }; auto resultValue = opOperand->get().dyn_cast(); - if (!def || !resultValue || !isScalarOrSplatConstantOp(def) || - !controlFn(resultValue, *opOperand)) + if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) continue; // The operands and the indexing_maps of the fused operation the same as @@ -1876,287 +1873,6 @@ } return failure(); } - -private: - ControlElementwiseOpsFusionFn controlFn; -}; - -/// Base class for constant folding linalg.generic ops with N inputs, 1 output, -/// and permutation indexing maps. -/// -/// `ConcreteType` should provide methods with signatures -/// -/// ```c++ -/// bool matchIndexingMaps(GenericOp genericOp) const; -/// RegionComputationFn getRegionComputeFn(GenericOp) const; -/// ``` -/// -/// The latter inspects the region and returns the computation inside as a -/// functor. The functor will be invoked with constant elements for all inputs -/// and should return the corresponding computea constant element for output. 
-template -class FoldConstantBase : public OpRewritePattern { -public: - struct APIntOrFloat { - Optional apInt; - Optional apFloat; - }; - struct APIntOrFloatArray { - SmallVector apInts; - SmallVector apFloats; - }; - using RegionComputationFn = - std::function; - - FoldConstantBase(MLIRContext *context, - const ControlElementwiseOpsFusionFn &controlFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(controlFn) {} - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (genericOp.hasBufferSemantics()) - return failure(); - - // Only support ops generating one output for now. - if (genericOp.getNumOutputs() != 1) - return failure(); - - auto outputType = genericOp.getResultTypes().front().dyn_cast(); - // Require the output types to be static give we are generating constants. - if (!outputType || !outputType.hasStaticShape()) - return failure(); - - if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().isa(); - })) - return failure(); - - // Make sure all element types are the same. - auto getOperandElementType = [](OpOperand *operand) { - return operand->get().getType().cast().getElementType(); - }; - if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), - getOperandElementType))) - return failure(); - - // We can only handle the case where we have int/float elements. - auto elementType = outputType.getElementType(); - if (!elementType.isIntOrFloat()) - return failure(); - - // Require all indexing maps to be permutations for now. This is common and - // it simplifies input/output access greatly: we can do the data shuffling - // entirely in the compiler, without needing to turn all indices into - // Values, and then do affine apply on them, and then match back the - // constant again. 
- if (!llvm::all_of(genericOp.getIndexingMaps(), - [](AffineMap map) { return map.isPermutation(); })) - return failure(); - - for (OpOperand *operand : genericOp.getOutputOperands()) { - if (genericOp.payloadUsesValueFromOperand(operand)) - return failure(); - } - - // Further check the indexing maps are okay for the ConcreteType. - if (!static_cast(this)->matchIndexingMaps(genericOp)) - return failure(); - - // Defer to the concrete type to check the region and discover the - // computation inside. - RegionComputationFn computeFn = - static_cast(this)->getRegionComputeFn(genericOp); - if (!computeFn) - return failure(); - - // All inputs should be constants. - int numInputs = genericOp.getNumInputs(); - SmallVector inputValues(numInputs); - for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { - if (!matchPattern(operand.value()->get(), - m_Constant(&inputValues[operand.index()]))) - return failure(); - } - - // Identified this as a potential candidate for folding. Now check the - // policy to see whether we are allowed to proceed. - for (int i = 0; i < numInputs; ++i) { - OpOperand *consumer = genericOp.getInputOperand(i); - OpResult producer = consumer->get().cast(); - if (!controlFn(producer, *consumer)) - return failure(); - } - - auto linalgOp = cast(genericOp.getOperation()); - SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); - int64_t numElements = outputType.getNumElements(); - - // Use APInt/APFloat instead of Attribute here for constructing the output. - // This helps to avoid blowing up compiler memory usage: Attributes would - // unify the following cases but they have lifetime as the MLIRContext. - SmallVector intOutputValues; - SmallVector fpOutputValues; - if (elementType.template isa()) - fpOutputValues.resize(numElements, APFloat(0.f)); - else - intOutputValues.resize(numElements); - - // Return the constant dim positions from the given permutation map. 
- auto getDimPositions = [](AffineMap map) { - SmallVector dims; - dims.reserve(map.getNumResults()); - for (AffineExpr result : map.getResults()) { - dims.push_back(result.cast().getPosition()); - } - return dims; - }; - - SmallVector> inputDims; - for (int i = 0; i < numInputs; ++i) - inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i])); - auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); - auto outputShape = outputType.getShape(); - - // Allocate small vectors for index delinearization. Initial values do not - // matter here as they will be overwritten later. - SmallVector indices(loopBounds.size(), 0); - SmallVector dstIndices(loopBounds.size(), 0); - SmallVector> srcIndices( - numInputs, SmallVector(loopBounds.size(), 0)); - SmallVector srcLinearIndices(numInputs, 0); - uint64_t dstLinearIndex = 0; - - // Allocate spaces for compute function inputs. Initial values do not matter - // here as they will be overwritten later. - APIntOrFloatArray computeFnInputs; - - auto inputShapes = llvm::to_vector<4>( - llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().cast().getShape(); - })); - - // Given a `linearIndex`, remap it to a linear index to access linalg op - // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, - // `srcLinearIndices`, `dstLinearIndex` in place. 
- auto computeRemappedLinearIndex = [&](int linearIndex) { - int totalCount = linearIndex; - for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { - indices[dim] = totalCount % loopBounds[dim]; - totalCount /= loopBounds[dim]; - } - - for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { - for (int i = 0; i < numInputs; ++i) - srcIndices[i][dim] = indices[inputDims[i][dim]]; - dstIndices[dim] = indices[outputDims[dim]]; - } - - dstLinearIndex = dstIndices.front(); - for (int i = 0; i < numInputs; ++i) - srcLinearIndices[i] = srcIndices[i].front(); - - for (int dim = 1; dim < outputType.getRank(); ++dim) { - dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - for (int i = 0; i < numInputs; ++i) - srcLinearIndices[i] = - srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; - } - }; - - bool isFloat = elementType.isa(); - if (isFloat) { - SmallVector> inFpRanges; - for (int i = 0; i < numInputs; ++i) - inFpRanges.push_back(inputValues[i].getValues()); - - computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); - - // Transpose the input constant. Because we don't know its rank in - // advance, we need to loop over the range [0, element count) and - // delinearize the index. - for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { - computeRemappedLinearIndex(linearIndex); - - // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) - computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; - - // Invoke the computation to get the corresponding constant output - // element. - fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; - } - } else { - SmallVector> inIntRanges; - for (int i = 0; i < numInputs; ++i) - inIntRanges.push_back(inputValues[i].getValues()); - - computeFnInputs.apInts.resize(numInputs); - - // Transpose the input constant. 
Because we don't know its rank in - // advance, we need to loop over the range [0, element count) and - // delinearize the index. - for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { - computeRemappedLinearIndex(linearIndex); - - // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) - computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; - - // Invoke the computation to get the corresponding constant output - // element. - intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; - } - } - - DenseElementsAttr outputAttr = - isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) - : DenseElementsAttr::get(outputType, intOutputValues); - - rewriter.replaceOpWithNewOp(genericOp, outputAttr); - return success(); - } - -private: - ControlElementwiseOpsFusionFn controlFn; -}; - -// Folds linalg.generic ops that are actually transposes on constant values. -struct FoldConstantTranspose : public FoldConstantBase { - using FoldConstantBase::FoldConstantBase; - - bool matchIndexingMaps(GenericOp genericOp) const { - // We should have one input and one output. - return genericOp.getIndexingMaps().size() == 2; - } - - RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { - // Make sure the region only contains a yield op. - Block &body = genericOp.region().front(); - if (!llvm::hasSingleElement(body)) - return nullptr; - auto yieldOp = dyn_cast(body.getTerminator()); - if (!yieldOp) - return nullptr; - - // The yield op should return the block argument corresponds to the input. - for (Value yieldVal : yieldOp.values()) { - auto yieldArg = yieldVal.dyn_cast(); - if (!yieldArg || yieldArg.getOwner() != &body) - return nullptr; - if (yieldArg.getArgNumber() != 0) - return nullptr; - } - - // No computation; just return the orginal value. 
- return [](const APIntOrFloatArray &inputs) { - if (inputs.apFloats.empty()) - return APIntOrFloat{inputs.apInts.front(), llvm::None}; - return APIntOrFloat{llvm::None, inputs.apFloats.front()}; - }; - } - - ControlElementwiseOpsFusionFn controlFn; }; } // namespace @@ -2264,7 +1980,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( RewritePatternSet &patterns, - const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { + const ControlFusionFn &controlFoldingReshapes) { patterns.add(patterns.getContext(), controlFoldingReshapes); patterns.add(patterns.getContext(), @@ -2273,27 +1989,18 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( RewritePatternSet &patterns, - const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { + const ControlFusionFn &controlFoldingReshapes) { patterns.add(patterns.getContext(), controlFoldingReshapes); } void mlir::linalg::populateElementwiseOpsFusionPatterns( - RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) { + RewritePatternSet &patterns, + const ControlFusionFn &controlElementwiseOpsFusion) { auto *context = patterns.getContext(); - patterns.add(context, - options.controlElementwiseOpsFusionFn); - patterns.add(context); - populateSparseTensorRewriting(patterns); - populateFoldReshapeOpsByExpansionPatterns(patterns, - options.controlFoldingReshapesFn); - AffineApplyOp::getCanonicalizationPatterns(patterns, context); - GenericOp::getCanonicalizationPatterns(patterns, context); - tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); - tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); - context->getLoadedDialect()->getCanonicalizationPatterns( - patterns); + patterns.add(context, controlElementwiseOpsFusion); + patterns.add(context); } void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { @@ -2321,19 +2028,44 @@ namespace { /// Pass that fuses generic ops on tensors. Used only for testing. 
+// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the +// patterns added here heavily depends on the cost function used. Having an +// opinionated pass of this form is not recommended. Deprecate this pass in +// favor of test passes that check the functionality of each of the patterns +// added here individually. struct LinalgElementwiseOpFusionPass : public LinalgElementwiseOpFusionBase { void runOnOperation() override { Operation *op = getOperation(); - RewritePatternSet patterns(op->getContext()); - ControlElementwiseOpsFusionFn allowFoldingFn = - [](const OpResult &producer, const OpOperand &consumer) { - return true; - }; - populateElementwiseOpsFusionPatterns( + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + + // Add folding with reshape by expansion patterns. + ControlFusionFn defaultControlFn = [](const OpResult &producer, + const OpOperand &consumer) { + return producer.hasOneUse(); + }; + + // Add elementwise op fusion patterns. + populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); + + populateFoldReshapeOpsByExpansionPatterns( patterns, - LinalgElementwiseFusionOptions().setControlFoldingReshapes( - allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape)); + allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape); + + // Add the sparse tensor rewriting patterns. + populateSparseTensorRewriting(patterns); + + // General canonicalization patterns. + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + GenericOp::getCanonicalizationPatterns(patterns, context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); + context->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + + // Add constant folding patterns. 
+ populateConstantFoldLinalgOperations(patterns, defaultControlFn); // Use TopDownTraversal for compile time reasons GreedyRewriteConfig grc; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -100,10 +100,8 @@ if (fuseGenericOps) { RewritePatternSet fusionPatterns(context); - linalg::populateElementwiseOpsFusionPatterns( - fusionPatterns, - linalg::LinalgElementwiseFusionOptions() - .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); + linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, + setFusedOpOperandLimit<4>); (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(fusionPatterns)); @@ -113,7 +111,7 @@ if (controlFuseByExpansion) { RewritePatternSet fusionPatterns(context); - linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = + linalg::ControlFusionFn controlReshapeFusionFn = [](const OpResult &producer, OpOperand &consumer) { if (auto collapseOp = producer.getDefiningOp()) { @@ -148,14 +146,16 @@ if (fuseWithReshapeByCollapsing) { RewritePatternSet patterns(context); - linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns); + linalg::populateFoldReshapeOpsByCollapsingPatterns( + patterns, [](const OpResult & /*producer*/, + OpOperand & /*consumer*/) { return true; }); (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); } if (fuseWithReshapeByCollapsingWithControlFn) { RewritePatternSet patterns(context); - linalg::ControlElementwiseOpsFusionFn controlFn = - [](const OpResult &producer, OpOperand &consumer) -> bool { + linalg::ControlFusionFn controlFn = [](const OpResult &producer, + OpOperand &consumer) -> bool { if (isa(producer.getDefiningOp())) { // Skip fusing the first operand. return consumer.getOperandNumber();