diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -123,6 +123,10 @@ "Return the range over inputs (irrespective of type) and output buffers.", "Operation::operand_range", "getInputsAndOutputBuffers" >, + InterfaceMethod< + "Return the shaped types for all the inputs and outputs", + "SmallVector", "getInputOutputShapedTypes" + >, //===------------------------------------------------------------------===// // Other interface methods. @@ -153,6 +157,10 @@ "Return the indexing maps attribute within the current operation.", "ArrayAttr", "indexing_maps" >, + InterfaceMethod< + "Return the indexing maps within the current operation.", + "SmallVector", "getIndexingMaps" + >, InterfaceMethod<"Return the input or output indexing map at index `i`.", "AffineMap", "getIndexingMap", (ins "unsigned":$i) >, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -217,6 +217,18 @@ return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] .template cast(); } + /// Return the shaped types of all operands, followed by all results. + SmallVector getInputOutputShapedTypes() { + SmallVector inputOutputTypes( + this->getOperation()->operand_type_begin(), + this->getOperation()->operand_type_end()); + inputOutputTypes.append(this->getOperation()->result_type_begin(), + this->getOperation()->result_type_end()); + return llvm::to_vector<4>( + llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType { + return type.cast(); + })); + } //==========================================================================// // Other interface methods.
@@ -295,6 +307,13 @@ return attr; } + SmallVector getIndexingMaps() { + return llvm::to_vector<4>( + llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap { + return attr.cast().getValue(); + })); + } + AffineMap getIndexingMap(unsigned i) { assert(i < getNumInputsAndOutputs()); return indexing_maps() diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -24,6 +24,8 @@ class OwningRewritePatternList; class Pass; +std::unique_ptr> createLinalgFoldUnitExtentDimsPass(); + std::unique_ptr> createLinalgFusionPass(); std::unique_ptr createLinalgFusionOfTensorOpsPass(); @@ -59,6 +61,11 @@ void populateLinalgTensorOpsFusionPatterns(MLIRContext *context, OwningRewritePatternList &patterns); +/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on +/// tensors. +void populateLinalgFoldUnitExtentDimsPatterns( + MLIRContext *context, OwningRewritePatternList &patterns); + } // namespace mlir #endif // MLIR_DIALECT_LINALG_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -11,6 +11,17 @@ include "mlir/Pass/PassBase.td" +def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> { + let summary = "Remove unit-extent dimension in Linalg ops on tensors"; + let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()"; + let options = [ + Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool", + /*default=*/"false", + "Only folds the one-trip loops from Linalg ops on tensors " + "(for testing purposes only)"> + ]; +} + def LinalgFusion : FunctionPass<"linalg-fusion"> { let summary = "Fuse operations in the linalg dialect"; let constructor = "mlir::createLinalgFusionPass()"; diff --git
a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -265,7 +265,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, ArrayRef mapsConsumer, MLIRContext *context) { - if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 || + if (mapsProducer.empty() || mapsConsumer.empty() || mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || mapsProducer.size() != mapsConsumer[0].getNumDims()) return nullptr; @@ -277,7 +277,7 @@ for (AffineExpr rhsExpr : rhs.getResults()) { AffineDimExpr dimExpr = rhsExpr.cast(); for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); - i != e; ++i) { + i < e; ++i) { reassociations.push_back(getAffineDimExpr(currDim++, context)); } } @@ -1129,8 +1129,6 @@ return {}; } OpFoldResult TensorReshapeOp::fold(ArrayRef) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); return foldReshapeOp(*this); } OpFoldResult TransposeOp::fold(ArrayRef) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRLinalgTransforms + DropUnitDims.cpp Fusion.cpp Interchange.cpp Loops.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -0,0 +1,375 @@ +//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns/pass to remove usage of unit-extent dimensions +// to specify broadcasting in favor of a more canonical representation of the +// computation. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/FoldUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "linalg-drop-unit-dims" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +/// Implements a pass that canonicalizes the uses of unit-extent dimensions for +/// broadcasting.
For example, +/// +/// ```mlir +/// #accesses = [ +/// affine_map<(d0, d1) -> (0, d1)>, +/// affine_map<(d0, d1) -> (d0, 0)>, +/// affine_map<(d0, d1) -> (d0, d1)> +/// ] +/// +/// #trait = { +/// args_in = 2, +/// args_out = 1, +/// indexing_maps = #accesses, +/// iterator_types = ["parallel", "parallel"], +/// library_call = "some_external_fn" +/// } +/// +/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> +/// tensor<5x5xf32> +/// { +/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : +/// tensor<5xf32> into tensor<1x5xf32> +/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : +/// tensor<5xf32> into tensor<5x1xf32> +/// %2 = linalg.generic #trait %0, %1 { +/// ^bb0(%arg2: f32, %arg3: f32): +/// %3 = addf %arg2, %arg3 : f32 +/// linalg.yield %3 : f32 +/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> +/// return %2 : tensor<5x5xf32> +/// } +/// +/// would canonicalize to +/// +/// ```mlir +/// #accesses = [ +/// affine_map<(d0, d1) -> (d1)>, +/// affine_map<(d0, d1) -> (d0)>, +/// affine_map<(d0, d1) -> (d0, d1)> +/// ] +/// +/// #trait = { +/// args_in = 2, +/// args_out = 1, +/// indexing_maps = #accesses, +/// iterator_types = ["parallel", "parallel"], +/// library_call = "some_external_fn" +/// } +/// +/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> +/// tensor<5x5xf32> +/// { +/// %0 = linalg.generic #trait %arg0, %arg1 { +/// ^bb0(%arg2: f32, %arg3: f32): +/// %3 = addf %arg2, %arg3 : f32 +/// linalg.yield %3 : f32 +/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32> +/// return %0 : tensor<5x5xf32> +/// } + +/// Given dims of the iteration space of a structured op that are known to be +/// single trip count (`unitDims`), return the indexing maps to use in the +/// canonicalized op with these dims removed, given the original `indexingMaps`.
+static ArrayAttr replaceUnitDims(DenseSet &unitDims, + ArrayRef indexingMaps, + MLIRContext *context) { + if (indexingMaps.empty()) + return nullptr; + unsigned numIterationDims = indexingMaps.front().getNumDims(); + unsigned numSymbols = indexingMaps.front().getNumSymbols(); + + // Compute the replacement for each dim expr. + SmallVector dimReplacements; + dimReplacements.reserve(numIterationDims); + unsigned numKeptDims = 0; + for (unsigned dim : llvm::seq(0, numIterationDims)) { + if (unitDims.count(dim)) + dimReplacements.push_back(getAffineConstantExpr(0, context)); + else + dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context)); + } + + // Symbols remain the same. + SmallVector symReplacements; + symReplacements.reserve(numSymbols); + for (unsigned symbol : llvm::seq(0, numSymbols)) + symReplacements.push_back(getAffineSymbolExpr(symbol, context)); + + SmallVector newIndexingMaps; + newIndexingMaps.reserve(indexingMaps.size()); + for (AffineMap operandMap : indexingMaps) { + // Indexing maps are expected to have no symbols; bail out otherwise. + if (operandMap.getNumSymbols()) + return nullptr; + newIndexingMaps.push_back(simplifyAffineMap( + operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements, + numIterationDims - unitDims.size(), + numSymbols))); + } + + // Check that the new index maps are invertible. If not, something went + // wrong, so abort. + if (!inversePermutation(concatAffineMaps(newIndexingMaps))) + return nullptr; + return ArrayAttr::get( + llvm::to_vector<4>(llvm::map_range( + newIndexingMaps, + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })), + context); +} + +namespace { +/// Pattern to fold unit-trip count loops in GenericOps. +// TODO: Generalize this to indexed-generic as well by modifying the region args +// as well.
+struct FoldUnitDimLoops : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + SmallVector indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.empty()) + return failure(); + + // Check if any of the iteration dimensions are unit-trip count. They will + // end up being unit-trip count if they are used to index into a unit-dim + // tensor/memref. Only loops with parallel iterator types are folded. + AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); + if (!invertedMap) + return failure(); + SmallVector dims; + for (ShapedType shapedType : genericOp.getInputOutputShapedTypes()) + dims.append(shapedType.getShape().begin(), shapedType.getShape().end()); + DenseSet unitDims; + ArrayAttr iteratorTypes = genericOp.iterator_types(); + for (auto expr : enumerate(invertedMap.getResults())) { + if (AffineDimExpr dimExpr = expr.value().dyn_cast()) + if (dims[dimExpr.getPosition()] == 1 && + iteratorTypes[expr.index()].dyn_cast().getValue() == + getParallelIteratorTypeName()) + unitDims.insert(expr.index()); + } + if (unitDims.empty()) + return failure(); + + // Compute the modified indexing maps. + MLIRContext *context = rewriter.getContext(); + ArrayAttr newIndexingMapAttr = + replaceUnitDims(unitDims, indexingMaps, context); + if (!newIndexingMapAttr) + return genericOp.emitError("unable to compute modified indexing_maps"); + + // Compute the iterator types of the modified op by dropping the one-trip + // count loops.
+ SmallVector newIteratorTypes; + for (auto attr : llvm::enumerate(iteratorTypes)) { + if (!unitDims.count(attr.index())) + newIteratorTypes.push_back(attr.value()); + } + + rewriter.startRootUpdate(genericOp); + genericOp.indexing_mapsAttr(newIndexingMapAttr); + genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context)); + rewriter.finalizeRootUpdate(genericOp); + return success(); + } +}; + +struct UnitExtentReplacementInfo { + RankedTensorType type; + AffineMap indexMap; + ArrayAttr reassociation; +}; +} // namespace + +/// Utility function for replacing operands/results of a linalg generic +/// operation on tensors with unit-extent dimensions. These can be replaced with +/// an operand/result with the unit-extent dimension removed. This is only done +/// if the indexing map used to access that dimension has the constant 0 (an +/// AffineConstantExpr) as its result. Given the `type` of a result/operand of a +/// Linalg op, and its `indexMap`, the utility function returns: +/// - the new type with these unit-extent dimensions removed. +/// - modified index map that can be used to access the replaced result/operand +/// - the reassociation that converts from the original tensor type to the +/// modified tensor type. +static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, + RankedTensorType type, + MLIRContext *context) { + ArrayRef shape = type.getShape(); + ArrayRef exprs = indexMap.getResults(); + SmallVector reassociations; + SmallVector reassociationMaps; + SmallVector newIndexExprs; + SmallVector newShape; + + int64_t origRank = type.getRank(); + AffineExpr zeroExpr = getAffineConstantExpr(0, context); + auto isUnitExtent = [&](int64_t dim) -> bool { + return shape[dim] == 1 && exprs[dim] == zeroExpr; + }; + + unsigned dim = 0; + // Fold dimensions that are unit-extent at the beginning of the tensor.
+ while (dim < origRank && isUnitExtent(dim)) + reassociations.push_back(getAffineDimExpr(dim++, context)); + while (dim < origRank) { + reassociations.push_back(getAffineDimExpr(dim, context)); + newIndexExprs.push_back(exprs[dim]); + newShape.push_back(shape[dim]); + // Fold all following dimensions that are unit-extent. + while (dim + 1 < origRank && isUnitExtent(dim + 1)) { + ++dim; + reassociations.push_back(getAffineDimExpr(dim, context)); + } + reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get( + origRank, /*numSymbols = */ 0, reassociations, context))); + reassociations.clear(); + ++dim; + } + UnitExtentReplacementInfo info = { + RankedTensorType::get(newShape, type.getElementType()), + AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), + newIndexExprs, context), + ArrayAttr::get(reassociationMaps, context)}; + return info; +} + +namespace { +/// Pattern to replace tensor operands/results that have unit-extent dimensions. +struct ReplaceUnitExtentTensors : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!genericOp.hasTensorSemantics()) + return failure(); + + MLIRContext *context = rewriter.getContext(); + Location loc = genericOp.getLoc(); + + SmallVector newIndexingMaps; + SmallVector reassociationMaps; + SmallVector newInputOutputTypes; + bool doCanonicalization = false; + for (auto it : llvm::zip(genericOp.getIndexingMaps(), + genericOp.getInputOutputShapedTypes())) { + auto replacementInfo = replaceUnitExtents( + std::get<0>(it), std::get<1>(it).cast(), context); + reassociationMaps.push_back(replacementInfo.reassociation); + newIndexingMaps.push_back(replacementInfo.indexMap); + newInputOutputTypes.push_back(replacementInfo.type); + doCanonicalization = + doCanonicalization || replacementInfo.type != std::get<1>(it); + } + + // If the indexing maps of the result operation are not invertible (i.e.
not +// legal) or if no operand/result type changed, abort. + if (!doCanonicalization || + !inversePermutation(concatAffineMaps(newIndexingMaps))) + return failure(); + + // If any operand type changes, insert a reshape to convert from the + // original type to the new type. + SmallVector newOperands; + newOperands.reserve(genericOp.getNumOperands()); + for (auto operand : llvm::enumerate(genericOp.getOperands())) { + if (operand.value().getType() == newInputOutputTypes[operand.index()]) { + newOperands.push_back(operand.value()); + } else { + newOperands.push_back(rewriter.create( + loc, newInputOutputTypes[operand.index()], operand.value(), + reassociationMaps[operand.index()])); + } + } + + // Compute the result types of the replacement op. Result entries in + // `newInputOutputTypes` follow all the operand entries. + SmallVector resultTypes; + resultTypes.reserve(genericOp.getNumResults()); + for (unsigned i : llvm::seq(0, genericOp.getNumResults())) + resultTypes.push_back( + newInputOutputTypes[i + genericOp.getNumOperands()]); + GenericOp replacementOp = rewriter.create( + loc, resultTypes, newOperands, genericOp.args_in(), + genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps), + genericOp.iterator_types(), + /*doc = */ nullptr, + /*library_call = */ nullptr); + rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(), + replacementOp.region().begin()); + + // If any result tensor has a modified shape, then add reshape to recover + // the original shape.
+ SmallVector resultReplacements; + for (auto result : llvm::enumerate(replacementOp.getResults())) { + unsigned index = result.index() + replacementOp.getNumOperands(); + RankedTensorType origResultType = genericOp.getResult(result.index()) + .getType() + .cast(); + if (origResultType != result.value().getType()) { + resultReplacements.push_back(rewriter.create( + loc, origResultType, result.value(), reassociationMaps[index])); + } else { + resultReplacements.push_back(result.value()); + } + } + rewriter.replaceOp(genericOp, resultReplacements); + return success(); + } +}; +} // namespace + +/// Patterns that are used to canonicalize the use of unit-extent dims for +/// broadcasting. +void mlir::populateLinalgFoldUnitExtentDimsPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert(context); + TensorReshapeOp::getCanonicalizationPatterns(patterns, context); +} + +namespace { +/// Pass that removes unit-extent dims within generic ops. +struct LinalgFoldUnitExtentDimsPass + : public LinalgFoldUnitExtentDimsBase { + void runOnFunction() override { + OwningRewritePatternList patterns; + FuncOp funcOp = getFunction(); + MLIRContext *context = funcOp.getContext(); + if (foldOneTripLoopsOnly) + patterns.insert(context); + else + populateLinalgFoldUnitExtentDimsPatterns(context, patterns); + applyPatternsAndFoldGreedily(funcOp.getBody(), patterns); + } +}; +} // namespace + +std::unique_ptr> +mlir::createLinalgFoldUnitExtentDimsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -575,8 +575,8 @@ if (auto yieldOp = dyn_cast(op)) { // Lookup the value the yield operation is mapped to.
Value yieldVal = yieldOp.getOperand(0); - auto clonedVal = mapper.lookup(yieldVal); - mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal); + if (Value clonedVal = mapper.lookupOrNull(yieldVal)) + mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal); continue; } rewriter.clone(op, mapper); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s + +#accesses = [ + affine_map<(i, j, k, l, m) -> (i, k, m)>, + affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> +] + +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], + indexing_maps = #accesses, + library_call = "some_external_func" +} + +func @drop_one_trip_loops(%arg0 : tensor) -> tensor +{ + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg1 : f32) : + linalg.yield %arg1 : f32 + } : tensor -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +// CHECK-DAG: #[[MAP6:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)> +// CHECK-LABEL: func @drop_one_trip_loops +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP4]], #[[MAP5]], #[[MAP6]]] + +// ----- + +#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0] +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel"], + indexing_maps = #access, + library_call = "some_external_func" +} + +func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32> +{ + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg1: f32) : + linalg.yield %arg1 : f32 + } : tensor<1x1xf32> -> tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> ()> +// CHECK-LABEL: func @drop_all_loops +// CHECK: linalg.tensor_reshape %{{.*}} [] +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] +// CHECK-SAME: iterator_types = [] + +// ----- + +#accesses = [ + affine_map<(d0) -> (0, d0)>, + affine_map<(d0) -> (d0)> +] + +#trait = { + args_in = 1, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel"], + library_call = "some_external_fn" +} + +func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> { + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } : tensor<1x5xf32> -> tensor<5xf32> + return %0 : tensor<5xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func @leading_dim_1_canonicalization +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]] +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel"] + +// ----- + +#accesses = [ + affine_map<(d0, d1) -> (0, d1)>, + affine_map<(d0, d1) -> (d0, 0)>, + affine_map<(d0, d1) -> (d0, d1)> +] + +#trait = { + args_in = 2, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel"], + library_call = "some_external_fn" +} + +func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5xf32> +{ + %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : + tensor<5xf32> into tensor<1x5xf32> + %1 = linalg.tensor_reshape %arg1 [affine_map<(d0,
d1) -> (d0, d1)>] : + tensor<5xf32> into tensor<5x1xf32> + %2 = linalg.generic #trait %0, %1 { + ^bb0(%arg2: f32, %arg3: f32): + %3 = addf %arg2, %arg3 : f32 + linalg.yield %3 : f32 + } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> + return %2 : tensor<5x5xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_test +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +#accesses = [ + affine_map<(d0, d1) -> (0, 0)>, + affine_map<(d0, d1) -> (d0, d1)> +] + +#trait = { + args_in = 1, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel"], + library_call = "some_external_fn" +} + +func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor +{ + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg1 : f32): + linalg.yield %arg1 : f32 + } : tensor<1x1xf32> -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_scalar +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32> +// CHECK: %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] [] +// CHECK-SAME: tensor<1x1xf32> into tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: %[[A]] diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir @@ -0,0 +1,110 @@ +// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" -split-input-file | FileCheck %s +
+#accesses = [ + affine_map<(i, j, k, l, m) -> (i, k, m)>, + affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> +] + +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], + indexing_maps = #accesses, + library_call = "some_external_func" +} + +func @drop_one_trip_loops(%arg0 : tensor) -> tensor +{ + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg1 : f32) : + linalg.yield %arg1 : f32 + } : tensor -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)> +// CHECK-LABEL: func @drop_one_trip_loops +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] + +// ----- + +#map0 = affine_map<(i, j) -> (i, j)> +#access = [#map0, #map0] +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel"], + indexing_maps = #access, + library_call = "some_external_func" +} + +func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32> +{ + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg1: f32) : + linalg.yield %arg1 : f32 + } : tensor<1x1xf32> -> tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)> +// CHECK-LABEL: func @drop_all_loops +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] +// CHECK-SAME: iterator_types = [] + +// ----- + +#map0 = affine_map<(i, j) -> (i, j)> +#access = [#map0, #map0] +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel"], + indexing_maps = #access, + library_call = "some_external_func" +} + +func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>) +{ + linalg.generic #trait %arg0, %arg1 { + ^bb0(%arg2: f32, %arg3 : f32) : + linalg.yield %arg2 : f32 + } : memref<1x1xf32>, memref<1x1xf32> + return +} +// CHECK-DAG:
#[[MAP0:.*]] = affine_map<() -> (0, 0)> +// CHECK-LABEL: func @drop_all_loops +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] +// CHECK-SAME: iterator_types = [] + +// ----- + +#accesses = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)> +] + +#trait = { + args_in = 1, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel"], + library_call = "some_external_fn" +} + +func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> { + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } : tensor<1x5xf32> -> tensor<5xf32> + return %0 : tensor<5xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (0, d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func @leading_dim_1_canonicalization +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel"]