diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -28,10 +28,6 @@ let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()"; let options = [ - Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool", - /*default=*/"false", - "Only folds the one-trip loops from Linalg ops on tensors " - "(for testing purposes only)">, Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool", /*default=*/"false", "Generate rank-reducing slices instead of reassociative reshapes"> diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -419,6 +419,25 @@ using LinalgLoops = SmallVector; +/// Transformation to drop unit-extent dimensions from `linalg.generic` +/// operations. +struct ControlDropUnitDims { + enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice }; + + RankReductionStrategy rankReductionStrategy = + RankReductionStrategy::ReassociativeReshape; + + using ControlFnTy = std::function(Operation *)>; + ControlFnTy controlFn = [](Operation *op) { + if (auto genericOp = dyn_cast_or_null(op)) { + return llvm::to_vector(llvm::seq(0, genericOp.getNumLoops())); + } + return SmallVector{}; + }; +}; +LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, + const ControlDropUnitDims &options); + /// Fuse two `linalg.generic` operations that have a producer-consumer /// relationship captured through `fusedOperand`. The method expects /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`. 
@@ -1496,11 +1515,8 @@ /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on /// tensors via reassociative reshape ops. -void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns); - -/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on -/// tensors via rank-reducing slices. -void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns); +void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, + ControlDropUnitDims &options); /// A pattern that converts init operands to input operands. void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -156,12 +156,16 @@ void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( RewritePatternSet &patterns) { - linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns); + linalg::ControlDropUnitDims options; + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); } void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns( RewritePatternSet &patterns) { - linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns); + linalg::ControlDropUnitDims options; + options.rankReductionStrategy = + linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); } void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -43,196 +43,6 @@ using namespace mlir::linalg; namespace 
{ -enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice }; -} // namespace - -/// Implements a pass that canonicalizes the uses of unit-extent dimensions for -/// broadcasting. For example, -/// -/// ```mlir -/// #accesses = [ -/// affine_map<(d0, d1) -> (0, d1)>, -/// affine_map<(d0, d1) -> (d0, 0)>, -/// affine_map<(d0, d1) -> (d0, d1)> -/// ] -/// -/// #trait = { -/// args_in = 2, -/// args_out = 1, -/// indexing_maps = #accesses, -/// iterator_types = ["parallel", "parallel"], -/// library_call = "some_external_fn" -/// } -/// -/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> -/// tensor<5x5xf32> -/// { -/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : -/// tensor<5xf32> into tensor<1x5xf32> -/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : -/// tensor<5xf32> into tensor<5x1xf32> -/// %2 = linalg.generic #trait %0, %1 { -/// ^bb0(%arg2: f32, %arg3: f32): -/// %3 = arith.addf %arg2, %arg3 : f32 -/// linalg.yield %3 : f32 -/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> -/// return %2 : tensor<5x5xf32> -/// } -/// -/// would canonicalize to -/// -/// ```mlir -/// #accesses = [ -/// affine_map<(d0, d1) -> (d1)>, -/// affine_map<(d0, d1) -> (d0)>, -/// affine_map<(d0, d1) -> (d0, d1)> -/// ] -/// -/// #trait = { -/// args_in = 2, -/// args_out = 1, -/// indexing_maps = #accesses, -/// iterator_types = ["parallel", "parallel"], -/// library_call = "some_external_fn" -/// } -/// -/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> -/// tensor<5x5xf32> -/// { -/// %0 = linalg.generic #trait %arg0, %arg1 { -/// ^bb0(%arg2: f32, %arg3: f32): -/// %3 = arith.addf %arg2, %arg3 : f32 -/// linalg.yield %3 : f32 -/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32> -/// return %0 : tensor<5x5xf32> -/// } - -/// Given dims of the iteration space of a structured op that are known to be -/// single trip count (`unitDims`), return the indexing 
maps to use in the -/// canonicalized op with these dims removed, given the original `indexingMaps`. -static ArrayAttr replaceUnitDims(DenseSet &unitDims, - ArrayRef indexingMaps, - MLIRContext *context) { - if (indexingMaps.empty()) - return nullptr; - unsigned numIterationDims = indexingMaps.front().getNumDims(); - unsigned numSymbols = indexingMaps.front().getNumSymbols(); - - // Compute the replacement for each dim expr. - SmallVector dimReplacements; - dimReplacements.reserve(numIterationDims); - unsigned numKeptDims = 0; - for (unsigned dim : llvm::seq(0, numIterationDims)) { - if (unitDims.count(dim)) - dimReplacements.push_back(getAffineConstantExpr(0, context)); - else - dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context)); - } - - // Symbols remain the same. - SmallVector symReplacements; - symReplacements.reserve(numSymbols); - for (unsigned symbol : llvm::seq(0, numSymbols)) - symReplacements.push_back(getAffineSymbolExpr(symbol, context)); - - SmallVector newIndexingMaps; - newIndexingMaps.reserve(indexingMaps.size()); - for (AffineMap operandMap : indexingMaps) { - // Expected indexing maps to have no symbols. - if (operandMap.getNumSymbols()) - return nullptr; - newIndexingMaps.push_back(simplifyAffineMap( - operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements, - numIterationDims - unitDims.size(), - numSymbols))); - } - - // Check that the new index maps are invertible. If not, something went - // wrong, so abort. - if (!inversePermutation(concatAffineMaps(newIndexingMaps))) - return nullptr; - return ArrayAttr::get(context, - llvm::to_vector<4>(llvm::map_range( - newIndexingMaps, [](AffineMap map) -> Attribute { - return AffineMapAttr::get(map); - }))); -} - -/// Update the index accesses of linalg operations having index semantics. 
-static void replaceUnitDimIndexOps(GenericOp genericOp, - const DenseSet &unitDims, - PatternRewriter &rewriter) { - for (IndexOp indexOp : - llvm::make_early_inc_range(genericOp.getBody()->getOps())) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(indexOp); - if (unitDims.count(indexOp.getDim()) != 0) { - rewriter.replaceOpWithNewOp(indexOp, 0); - } else { - // Update the dimension of the index operation if needed. - unsigned droppedDims = llvm::count_if( - unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); }); - if (droppedDims != 0) - rewriter.replaceOpWithNewOp(indexOp, - indexOp.getDim() - droppedDims); - } - } -} - -namespace { -/// Pattern to fold unit-trip count loops in GenericOps. -struct FoldUnitDimLoops : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - SmallVector indexingMaps = genericOp.getIndexingMapsArray(); - if (indexingMaps.empty()) - return failure(); - - // Check if any of the iteration dimensions are unit-trip count. They will - // end up being unit-trip count if they are used to index into a unit-dim - // tensor/memref. - AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); - if (!invertedMap) - return failure(); - SmallVector dims = genericOp.getStaticShape(); - - DenseSet unitDims; - SmallVector unitDimsReductionLoops; - ArrayAttr iteratorTypes = genericOp.getIteratorTypes(); - for (const auto &expr : enumerate(invertedMap.getResults())) { - if (AffineDimExpr dimExpr = expr.value().dyn_cast()) - if (dims[dimExpr.getPosition()] == 1) - unitDims.insert(expr.index()); - } - - if (unitDims.empty()) - return failure(); - - // Compute the modified indexing maps. 
- MLIRContext *context = rewriter.getContext(); - ArrayAttr newIndexingMapAttr = - replaceUnitDims(unitDims, indexingMaps, context); - if (!newIndexingMapAttr) - return genericOp.emitError("unable to compute modified indexing_maps"); - - // Compute the iterator types of the modified op by dropping the one-trip - // count loops. - SmallVector newIteratorTypes; - for (const auto &attr : llvm::enumerate(iteratorTypes)) { - if (!unitDims.count(attr.index())) - newIteratorTypes.push_back(attr.value()); - } - - rewriter.startRootUpdate(genericOp); - genericOp.setIndexingMapsAttr(newIndexingMapAttr); - genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes)); - replaceUnitDimIndexOps(genericOp, unitDims, rewriter); - rewriter.finalizeRootUpdate(genericOp); - return success(); - } -}; - /// Pattern to move init operands to ins when all the loops are parallel and /// blockArgument corresponding to init is used in the region. This is a fix-up /// when unit reduction dimensions are all folded away. In this context, it @@ -351,243 +161,405 @@ return success(); } }; +} // namespace + +//===---------------------------------------------------------------------===// +// Drop loops that are unit-extents within Linalg operations. +//===---------------------------------------------------------------------===// + +/// Implements a pass that canonicalizes the uses of unit-extent dimensions for +/// broadcasting. 
For example, +/// +/// ```mlir +/// #accesses = [ +/// affine_map<(d0, d1) -> (0, d1)>, +/// affine_map<(d0, d1) -> (d0, 0)>, +/// affine_map<(d0, d1) -> (d0, d1)> +/// ] +/// +/// #trait = { +/// args_in = 2, +/// args_out = 1, +/// indexing_maps = #accesses, +/// iterator_types = ["parallel", "parallel"], +/// library_call = "some_external_fn" +/// } +/// +/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> +/// tensor<5x5xf32> +/// { +/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : +/// tensor<5xf32> into tensor<1x5xf32> +/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : +/// tensor<5xf32> into tensor<5x1xf32> +/// %2 = linalg.generic #trait %0, %1 { +/// ^bb0(%arg2: f32, %arg3: f32): +/// %3 = arith.addf %arg2, %arg3 : f32 +/// linalg.yield %3 : f32 +/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> +/// return %2 : tensor<5x5xf32> +/// } +/// +/// would canonicalize to +/// +/// ```mlir +/// #accesses = [ +/// affine_map<(d0, d1) -> (d1)>, +/// affine_map<(d0, d1) -> (d0)>, +/// affine_map<(d0, d1) -> (d0, d1)> +/// ] +/// +/// #trait = { +/// args_in = 2, +/// args_out = 1, +/// indexing_maps = #accesses, +/// iterator_types = ["parallel", "parallel"], +/// library_call = "some_external_fn" +/// } +/// +/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> +/// tensor<5x5xf32> +/// { +/// %0 = linalg.generic #trait %arg0, %arg1 { +/// ^bb0(%arg2: f32, %arg3: f32): +/// %3 = arith.addf %arg2, %arg3 : f32 +/// linalg.yield %3 : f32 +/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32> +/// return %0 : tensor<5x5xf32> +/// } +/// Update the index accesses of linalg operations having index semantics. 
+static void +replaceUnitDimIndexOps(GenericOp genericOp, + const llvm::SmallDenseSet &unitDims, + RewriterBase &rewriter) { + for (IndexOp indexOp : + llvm::make_early_inc_range(genericOp.getBody()->getOps())) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(indexOp); + if (unitDims.count(indexOp.getDim()) != 0) { + rewriter.replaceOpWithNewOp(indexOp, 0); + } else { + // Update the dimension of the index operation if needed. + unsigned droppedDims = llvm::count_if( + unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); }); + if (droppedDims != 0) + rewriter.replaceOpWithNewOp(indexOp, + indexOp.getDim() - droppedDims); + } + } +} + +/// Expand the given `value` so that the type matches the type of `origDest`. +/// The `reassociation` is used when `rankReductionStrategy` is set to +/// `RankReductionStrategy::ReassociativeReshape`. +static Value +expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, + ArrayRef reassociation, + ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { + // There are no results for memref outputs. + auto origResultType = cast(origDest.getType()); + if (rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { + unsigned rank = origResultType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector sizes = + tensor::getMixedSizes(rewriter, loc, origDest); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + return rewriter.createOrFold( + loc, result, origDest, offsets, sizes, strides); + } + + assert(rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + return rewriter.create(loc, origResultType, result, + reassociation); +} + +/// Collapse the given `value` so that the type matches the type of +/// `origOutput`. 
The `reassociation` is used when `rankReductionStrategy` is +/// set to `RankReductionStrategy::ReassociativeReshape`. +static Value collapseValue( + RewriterBase &rewriter, Location loc, Value operand, + ArrayRef targetShape, ArrayRef reassociation, + ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { + if (auto memrefType = dyn_cast(operand.getType())) { + if (rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { + FailureOr rankReducingExtract = + memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, + targetShape); + assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); + return *rankReducingExtract; + } + + assert( + rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + MemRefLayoutAttrInterface layout; + auto targetType = MemRefType::get(targetShape, memrefType.getElementType(), + layout, memrefType.getMemorySpace()); + return rewriter.create(loc, targetType, operand, + reassociation); + } + if (auto tensorType = dyn_cast(operand.getType())) { + if (rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { + FailureOr rankReducingExtract = + tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, + targetShape); + assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); + return *rankReducingExtract; + } + + assert( + rankReductionStrategy == + ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + auto targetType = + RankedTensorType::get(targetShape, tensorType.getElementType()); + return rewriter.create(loc, targetType, operand, + reassociation); + } + llvm_unreachable("unsupported operand type"); +} + +/// Compute the modified metadata for an operand of an operation +/// whose unit dims are being dropped.
Return the new indexing map +/// to use, the shape of the operand in the replacement op +/// and the `reassociation` to use to go from original operand shape +/// to modified operand shape. struct UnitExtentReplacementInfo { AffineMap indexMap; SmallVector reassociation; SmallVector targetShape; }; -} // namespace - -/// Utility function for replacing operands/results to a linalg generic -/// operation with unit-extent dimensions. These can be replaced with -/// an operand/result with the unit-extent dimension removed. This is only done -/// if the indexing map used to access that dimension has a -/// AffineConstantExpr of value 0. Given the `type` of an result/operand of a -/// Linalg op, and its `indexMap` the utility function returns: -/// - the new type with dimensions of size 1 removed. -/// - modified index map that can be used to access the replaced result/operand -/// - the reassociation that converts from the original tensor type to the -/// modified tensor type. -static std::optional -replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand, - MLIRContext *context) { +static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( + MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, + llvm::SmallDenseMap &oldDimsToNewDimsMap, + ArrayRef dimReplacements) { + UnitExtentReplacementInfo info; + ReassociationIndices reassociationGroup; + SmallVector newIndexExprs; AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - ArrayRef shape = genericOp.getShape(opOperand); + ArrayRef operandShape = genericOp.getShape(opOperand); ArrayRef exprs = indexingMap.getResults(); - SmallVector newIndexExprs; - SmallVector newShape; - int64_t origRank = genericOp.getRank(opOperand); - AffineExpr zeroExpr = getAffineConstantExpr(0, context); - auto isUnitExtent = [&](int64_t dim) -> bool { - return shape[dim] == 1 && exprs[dim] == zeroExpr; + auto isUnitDim = [&](unsigned dim) { + if (auto dimExpr = exprs[dim].dyn_cast()) { + unsigned
oldPosition = dimExpr.getPosition(); + return !oldDimsToNewDimsMap.count(oldPosition); + } + // Handle the other case where the shape is 1, and is accessed using a + // constant 0. + if (operandShape[dim] == 1) { + auto constAffineExpr = exprs[dim].dyn_cast(); + return constAffineExpr && constAffineExpr.getValue() == 0; + } + return false; }; - // Early return for memrefs with affine maps to represent that we will always - // leave them unchanged. - Type actualType = opOperand->get().getType(); - if (auto memref = dyn_cast(actualType)) { - if (!memref.getLayout().isIdentity()) - return std::nullopt; - } - int64_t dim = 0; - SmallVector reassociation; - ReassociationIndices reassociationGroup; - // Fold dimensions that are unit-extent at the beginning of the tensor. - while (dim < origRank && isUnitExtent(dim)) + while (dim < operandShape.size() && isUnitDim(dim)) reassociationGroup.push_back(dim++); - while (dim < origRank) { - assert(!isUnitExtent(dim) && "expected non unit-extent"); + while (dim < operandShape.size()) { + assert(!isUnitDim(dim) && "expected non unit-extent"); reassociationGroup.push_back(dim); - newIndexExprs.push_back(exprs[dim]); - newShape.push_back(shape[dim]); + AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements); + newIndexExprs.push_back(newExpr); + info.targetShape.push_back(operandShape[dim]); ++dim; // Fold all following dimensions that are unit-extent. - while (dim < origRank && isUnitExtent(dim)) + while (dim < operandShape.size() && isUnitDim(dim)) { reassociationGroup.push_back(dim++); - reassociation.push_back(reassociationGroup); + } + info.reassociation.push_back(reassociationGroup); reassociationGroup.clear(); } - - // Return if the rank was not reduced. 
- if (origRank == static_cast(newShape.size())) - return std::nullopt; - - UnitExtentReplacementInfo info = { - /*indexMap=*/AffineMap::get(indexingMap.getNumDims(), - indexingMap.getNumSymbols(), newIndexExprs, - context), - /*reassociation=*/reassociation, /*targetShape=*/newShape}; + info.indexMap = + AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(), + newIndexExprs, context); return info; } -namespace { +LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, + const ControlDropUnitDims &options) { + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + if (indexingMaps.empty()) + return failure(); + + // 1. Check if any of the iteration dimensions are unit-trip count. They will + // end up being unit-trip count if they are used to index into a unit-dim + // tensor/memref. + AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); + if (!invertedMap) { + return rewriter.notifyMatchFailure(genericOp, + "invalid indexing maps for operation"); + } + SmallVector dims = genericOp.getStaticShape(); -/// Pattern to replace tensor/buffer operands/results that are unit extents. -struct ReplaceUnitExtents : public OpRewritePattern { - ReplaceUnitExtents(MLIRContext *ctx, - RankReductionStrategy rankReductionStrategy) - : OpRewritePattern(ctx), - rankReductionStrategy(rankReductionStrategy) {} - - // Expand the given value. - Value expandValue(Value result, Value origOutput, - ArrayRef reassociation, Location loc, - PatternRewriter &rewriter) const { - // There are no results for memref outputs. 
- auto origResultType = cast(origOutput.getType()); - if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { - unsigned rank = origResultType.getRank(); - SmallVector offsets(rank, rewriter.getIndexAttr(0)); - SmallVector sizes = - tensor::getMixedSizes(rewriter, loc, origOutput); - SmallVector strides(rank, rewriter.getIndexAttr(1)); - return rewriter.createOrFold( - loc, result, origOutput, offsets, sizes, strides); + // 1a. Get the allowed list of dimensions to drop from the `options`. + SmallVector allowedUnitDims = options.controlFn(genericOp); + if (allowedUnitDims.empty()) { + return rewriter.notifyMatchFailure( + genericOp, "control function returns no allowed unit dims to prune"); + } + llvm::SmallDenseSet unitDimsFilter(allowedUnitDims.begin(), + allowedUnitDims.end()); + llvm::SmallDenseSet unitDims; + ArrayAttr iteratorTypes = genericOp.getIteratorTypes(); + for (const auto &expr : enumerate(invertedMap.getResults())) { + if (AffineDimExpr dimExpr = expr.value().dyn_cast()) { + if (dims[dimExpr.getPosition()] == 1 && + unitDimsFilter.count(expr.index())) + unitDims.insert(expr.index()); } + } - assert(rankReductionStrategy == - RankReductionStrategy::ReassociativeReshape && - "unknown rank reduction strategy"); - return rewriter.create(loc, origResultType, result, - reassociation); + // 2. Compute the iterator types of the modified op by dropping the one-trip + // count loops. + SmallVector newIteratorTypes; + llvm::SmallDenseMap oldDimToNewDimMap; + SmallVector dimReplacements; + unsigned newDims = 0; + for (auto [index, attr] : + llvm::enumerate(genericOp.getIteratorTypesArray())) { + if (unitDims.count(index)) { + dimReplacements.push_back( + getAffineConstantExpr(0, rewriter.getContext())); + } else { + newIteratorTypes.push_back(attr); + oldDimToNewDimMap[index] = newDims; + dimReplacements.push_back( + getAffineDimExpr(newDims, rewriter.getContext())); + newDims++; + } } - // Collapse the given value. 
- Value collapseValue(Value operand, ArrayRef targetShape, - ArrayRef reassociation, - Location loc, PatternRewriter &rewriter) const { - if (auto memrefType = dyn_cast(operand.getType())) { - if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { - FailureOr rankReducingExtract = - memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, - targetShape); - assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); - return *rankReducingExtract; - } - - assert(rankReductionStrategy == - RankReductionStrategy::ReassociativeReshape && - "unknown rank reduction strategy"); - MemRefLayoutAttrInterface layout; - auto targetType = - MemRefType::get(targetShape, memrefType.getElementType(), layout, - memrefType.getMemorySpace()); - return rewriter.create(loc, targetType, operand, - reassociation); + // 3. For each of the operands, find the + // - modified affine map to use. + // - shape of the operands after the unit-dims are dropped. + // - the reassociation indices used to convert from the original + // operand type to modified operand (needed only when using reshapes + // for rank reduction strategy) + // Note that the indexing maps might need changing even if there are no + // unit dimensions that are dropped to handle cases where `0` is used to + // access a unit-extent tensor. Consider moving this out of this specific + // transformation as a stand-alone transformation. Kept here right now due + // to legacy. 
+ SmallVector newIndexingMaps; + SmallVector> reassociations; + SmallVector> targetShapes; + SmallVector collapsed; + auto hasCollapsibleType = [](OpOperand &operand) { + Type operandType = operand.get().getType(); + if (auto memrefOperandType = dyn_cast_or_null(operandType)) { + return memrefOperandType.getLayout().isIdentity(); + } else if (auto tensorOperandType = + dyn_cast(operandType)) { + return tensorOperandType.getEncoding() == nullptr; } - if (auto tensorType = dyn_cast(operand.getType())) { - if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { - FailureOr rankReducingExtract = - tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, - targetShape); - assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); - return *rankReducingExtract; - } - - assert(rankReductionStrategy == - RankReductionStrategy::ReassociativeReshape && - "unknown rank reduction strategy"); - auto targetType = - RankedTensorType::get(targetShape, tensorType.getElementType()); - return rewriter.create(loc, targetType, operand, - reassociation); + return false; + }; + for (OpOperand &opOperand : genericOp->getOpOperands()) { + auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand); + ArrayRef shape = genericOp.getShape(&opOperand); + if (!hasCollapsibleType(opOperand)) { + AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols( + dimReplacements, ArrayRef{}, oldDimToNewDimMap.size(), 0); + newIndexingMaps.push_back(newIndexingMap); + targetShapes.push_back(llvm::to_vector(shape)); + collapsed.push_back(false); + reassociations.push_back({}); + continue; } - llvm_unreachable("unsupported operand type"); + auto replacementInfo = dropUnitExtentFromOperandMetadata( + rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap, + dimReplacements); + reassociations.push_back(replacementInfo.reassociation); + newIndexingMaps.push_back(replacementInfo.indexMap); + targetShapes.push_back(replacementInfo.targetShape); + 
collapsed.push_back(!(replacementInfo.indexMap.getNumResults() == + indexingMap.getNumResults())); } - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - // Skip the pattern if the op has any tensor with special encoding. - if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { - auto tensorType = dyn_cast(type); - return tensorType && tensorType.getEncoding() != nullptr; - })) - return failure(); - MLIRContext *context = rewriter.getContext(); - Location loc = genericOp.getLoc(); - SmallVector oldOutputs(genericOp.getOutputs().begin(), - genericOp.getOutputs().end()); - - SmallVector newIndexingMaps; - SmallVector> reassociations; - SmallVector> targetShapes; - SmallVector collapsed; - for (OpOperand &opOperand : genericOp->getOpOperands()) { - auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context); - if (replacementInfo) { - reassociations.push_back(replacementInfo->reassociation); - newIndexingMaps.push_back(replacementInfo->indexMap); - targetShapes.push_back(replacementInfo->targetShape); - collapsed.push_back(true); - } else { - // If replaceUnitExtents cannot handle this case (or no unit dim was - // removed), maintain the same type, indexing map, and create a set of - // mappings representing an identity matrix. - newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand)); - reassociations.emplace_back(); - targetShapes.emplace_back(); - collapsed.push_back(false); - } + // Abort if the indexing maps of the result operation are not invertible + // (i.e. not legal) or if no dimension was reduced. + if (newIndexingMaps == indexingMaps || + !inversePermutation(concatAffineMaps(newIndexingMaps))) + return failure(); + + Location loc = genericOp.getLoc(); + // 4. 
For each of the operands, collapse the operand to convert + // from original shape to shape in the modified operation if needed, + // either through use of reshapes or rank-reducing slices as + // specified in `options`. + SmallVector newOperands; + for (OpOperand &opOperand : genericOp->getOpOperands()) { + int64_t idx = opOperand.getOperandNumber(); + if (!collapsed[idx]) { + newOperands.push_back(opOperand.get()); + continue; } + newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(), + targetShapes[idx], reassociations[idx], + options.rankReductionStrategy)); + } - // Abort if the indexing maps of the result operation are not invertible - // (i.e. not legal) or if no dimension was reduced. - if (!llvm::any_of(collapsed, [](bool c) { return c; }) || - !inversePermutation(concatAffineMaps(newIndexingMaps))) - return failure(); - - // Insert rank reductions. - SmallVector newOperands; - for (OpOperand &opOperand : genericOp->getOpOperands()) { - int64_t idx = opOperand.getOperandNumber(); - if (!collapsed[idx]) { - newOperands.push_back(opOperand.get()); - continue; - } - newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx], - reassociations[idx], loc, rewriter)); + // 5. Create the `linalg.generic` operation with the new operands, + // indexing maps, iterator types and result types. + ArrayRef newInputs = + ArrayRef(newOperands).take_front(genericOp.getNumDpsInputs()); + ArrayRef newOutputs = + ArrayRef(newOperands).take_back(genericOp.getNumDpsInits()); + SmallVector resultTypes; + resultTypes.reserve(genericOp.getNumResults()); + for (unsigned i : llvm::seq(0, genericOp.getNumResults())) + resultTypes.push_back(newOutputs[i].getType()); + GenericOp replacementOp = + rewriter.create(loc, resultTypes, newInputs, newOutputs, + newIndexingMaps, newIteratorTypes); + rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), + replacementOp.getRegion().begin()); + // 5a. 
Replace `linalg.index` operations that refer to the dropped unit + // dimensions. + replaceUnitDimIndexOps(replacementOp, unitDims, rewriter); + + // 6. If any result type changes, insert a reshape/slice to convert the + // result from the new type back to the + // original type. + SmallVector resultReplacements; + for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) { + unsigned opOperandIndex = index + replacementOp.getNumDpsInputs(); + Value origDest = genericOp.getDpsInitOperand(index)->get(); + if (!collapsed[opOperandIndex]) { + resultReplacements.push_back(result); + continue; } + resultReplacements.push_back(expandValue(rewriter, loc, result, origDest, + reassociations[opOperandIndex], + options.rankReductionStrategy)); + } - // If any result type changes, insert a reshape to convert from the original - // type to the new type. - ArrayRef newInputs = - ArrayRef(newOperands).take_front(genericOp.getNumDpsInputs()); - ArrayRef newOutputs = - ArrayRef(newOperands).take_back(genericOp.getNumDpsInits()); - SmallVector resultTypes; - resultTypes.reserve(genericOp.getNumResults()); - for (unsigned i : llvm::seq(0, genericOp.getNumResults())) - resultTypes.push_back(newOutputs[i].getType()); - GenericOp replacementOp = rewriter.create( - loc, resultTypes, newInputs, newOutputs, newIndexingMaps, - genericOp.getIteratorTypesArray()); - rewriter.inlineRegionBefore(genericOp.getRegion(), - replacementOp.getRegion(), - replacementOp.getRegion().begin()); - - // If any result tensor has a modified shape, then add reshape to recover - // the original shape.
- SmallVector resultReplacements; - for (const auto &result : llvm::enumerate(replacementOp.getResults())) { - unsigned index = result.index() + replacementOp.getNumDpsInputs(); - Value origOutput = oldOutputs[result.index()]; - if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) { - resultReplacements.push_back(result.value()); - continue; - } - resultReplacements.push_back(expandValue( - result.value(), origOutput, reassociations[index], loc, rewriter)); - } + rewriter.replaceOp(genericOp, resultReplacements); + return success(); +} - rewriter.replaceOp(genericOp, resultReplacements); - return success(); +namespace { +struct DropUnitDims : public OpRewritePattern { + DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {}, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(std::move(options)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + return dropUnitDims(rewriter, genericOp, options); } private: - RankReductionStrategy rankReductionStrategy; + ControlDropUnitDims options; }; } // namespace @@ -641,8 +613,8 @@ tensor::CollapseShapeOp reshapedSource; { OpBuilder::InsertionGuard g(rewriter); - // The only difference between InsertSliceOp and ParallelInsertSliceOp is - // the insertion point is just before the ParallelCombiningOp in the + // The only difference between InsertSliceOp and ParallelInsertSliceOp + // is the insertion point is just before the ParallelCombiningOp in the // parallel case. if (std::is_same::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); @@ -660,13 +632,13 @@ /// Patterns that are used to canonicalize the use of unit-extent dims for /// broadcasting. 
-void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns( - RewritePatternSet &patterns) { +static void +populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, + ControlDropUnitDims &options) { auto *context = patterns.getContext(); - patterns.add(context, - RankReductionStrategy::ReassociativeReshape); + patterns.add(context, options); // TODO: Patterns unrelated to unit dim folding should be factored out. - patterns.add, RankReducedInsertSliceOp>( context); @@ -679,12 +651,13 @@ memref::populateResolveShapedTypeResultDimsPatterns(patterns); } -void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns( - RewritePatternSet &patterns) { +static void +populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, + ControlDropUnitDims &options) { auto *context = patterns.getContext(); - patterns.add(context, - RankReductionStrategy::ExtractInsertSlice); - patterns.add(context); + options.rankReductionStrategy = + ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; + patterns.add(context, options); // TODO: Patterns unrelated to unit dim folding should be factored out. 
linalg::FillOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); @@ -693,6 +666,18 @@ memref::populateResolveShapedTypeResultDimsPatterns(patterns); } +void mlir::linalg::populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) { + if (options.rankReductionStrategy == + linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { + populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options); + } else if (options.rankReductionStrategy == + linalg::ControlDropUnitDims::RankReductionStrategy:: + ReassociativeReshape) { + populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options); + } +} + void mlir::linalg::populateMoveInitOperandsToInputPattern( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); @@ -706,15 +691,13 @@ Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - if (foldOneTripLoopsOnly) { - patterns.add(context); - } else if (useRankReducingSlices) { - populateFoldUnitExtentDimsViaSlicesPatterns(patterns); - populateMoveInitOperandsToInputPattern(patterns); - } else { - populateFoldUnitExtentDimsViaReshapesPatterns(patterns); - populateMoveInitOperandsToInputPattern(patterns); + ControlDropUnitDims options; + if (useRankReducingSlices) { + options.rankReductionStrategy = linalg::ControlDropUnitDims:: + RankReductionStrategy::ExtractInsertSlice; } + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); + populateMoveInitOperandsToInputPattern(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -59,24 +59,24 @@ library_call = "some_external_func" } -func.func 
@drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor) -> tensor { +func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<1x1x?x1x1xf32>) -> tensor<1x1x?x1x1xf32> { %0 = linalg.generic #trait ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32) - outs(%shape : tensor) { + outs(%shape : tensor<1x1x?x1x1xf32>) { ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) : linalg.yield %arg3 : f32 - } -> tensor - return %0 : tensor + } -> tensor<1x1x?x1x1xf32> + return %0 : tensor<1x1x?x1x1xf32> } // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @drop_one_trip_loops_all_ones // CHECK: tensor.collapse_shape %{{.*}} [] -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel"] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]] // ----- @@ -922,4 +922,3 @@ // CHECK-SLICES-LABEL: func @drop_all_loops // CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref, 3> // CHECK-SLICES: linalg.generic{{.*}}memref, 3> - diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir +++ /dev/null @@ -1,110 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims{fold-one-trip-loops-only}))" | FileCheck %s - -#accesses = [ - affine_map<(i, j, k, l, m) -> (i, k, m)>, - affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> -] - -#trait = { - iterator_types = ["parallel", "parallel", "parallel", 
"parallel", "parallel"], - indexing_maps = #accesses, - library_call = "some_external_func" -} - -func.func @drop_one_trip_loops(%arg0 : tensor, %shape: tensor) -> tensor -{ - %0 = linalg.generic #trait - ins(%arg0 : tensor) - outs(%shape : tensor) { - ^bb0(%arg1 : f32, %arg2 : f32) : - linalg.yield %arg1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)> -// CHECK-LABEL: func @drop_one_trip_loops -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] - -// ----- - -#map0 = affine_map<(i, j) -> (i, j)> -#access = [#map0, #map0] -#trait = { - iterator_types = ["parallel", "parallel"], - indexing_maps = #access, - library_call = "some_external_func" -} - -func.func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32> -{ - %0 = linalg.generic #trait - ins(%arg0 : tensor<1x1xf32>) - outs(%arg0 : tensor<1x1xf32>) { - ^bb0(%arg1: f32, %arg2: f32) : - linalg.yield %arg1 : f32 - } -> tensor<1x1xf32> - return %0 : tensor<1x1xf32> -} -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)> -// CHECK-LABEL: func @drop_all_loops -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] -// CHECK-SAME: iterator_types = [] - -// ----- - -#map0 = affine_map<(i, j) -> (i, j)> -#access = [#map0, #map0] -#trait = { - iterator_types = ["parallel", "parallel"], - indexing_maps = #access, - library_call = "some_external_func" -} - -func.func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>) -{ - linalg.generic #trait - ins(%arg0 : memref<1x1xf32>) - outs(%arg1 : memref<1x1xf32>) { - ^bb0(%arg2: f32, %arg3 : f32) : - linalg.yield %arg2 : f32 - } - return -} -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)> -// CHECK-LABEL: func @drop_all_loops -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps 
= [#[[$MAP0]], #[[$MAP0]]] -// CHECK-SAME: iterator_types = [] - -// ----- - -#accesses = [ - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1)> -] - -#trait = { - indexing_maps = #accesses, - iterator_types = ["parallel", "parallel"], - library_call = "some_external_fn" -} - -func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> { - %0 = linalg.generic #trait - ins(%arg0 : tensor<1x5xf32>) - outs(%shape : tensor<5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<5xf32> - return %0 : tensor<5xf32> -} -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func @leading_dim_1_canonicalization -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-SAME: iterator_types = ["parallel"] diff --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s + +// Drop only the outermost unit dimension (controlled using a control function) +func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> { + %0 = tensor.empty() : tensor<1x1x42xf32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<1x1x42xf32>) outs(%0 : tensor<1x1x42xf32>) { + ^bb0(%b0: f32, %b1 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<1x1x42xf32> + return %1 : tensor<1x1x42xf32> +} +// CHECK-LABEL: func @drop_outermost_unit_dims +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x42xf32> +// CHECK: %[[OUTS:.+]] = tensor.empty() +// CHECK: %[[ARG0_RESHAPE:.+]] = 
tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}} +// CHECK: %[[OUTS_RESHAPE:.+]] = tensor.collapse_shape %[[OUTS]] {{\[}}[0, 1], [2]{{\]}} +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_RESHAPE]] : +// CHECK-SAME: outs(%[[OUTS_RESHAPE]] : +// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}} +// CHECK: return %[[EXPAND_SHAPE]] diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRLinalgTestPasses TestDataLayoutPropagation.cpp TestLinalgDecomposeOps.cpp + TestLinalgDropUnitDims.cpp TestLinalgElementwiseFusion.cpp TestLinalgFusionTransforms.cpp TestLinalgTransforms.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp @@ -0,0 +1,73 @@ +//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing the transformation to drop unit +// extent dimensions from `linalg.generic` operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +/// Runs `linalg::dropUnitDims` with only loop 0 as a candidate unit dim. +LogicalResult dropOutermostUnitDims(RewriterBase &rewriter, + linalg::GenericOp genericOp) { + linalg::ControlDropUnitDims options; + options.controlFn = [](Operation *) { return SmallVector<unsigned>{0}; }; + return linalg::dropUnitDims(rewriter, genericOp, options); +} + +struct TestLinalgDropUnitDims + : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims) + + TestLinalgDropUnitDims() = default; + TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<linalg::LinalgDialect>(); + } + + StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; } + + StringRef getDescription() const final { + return "Test transformation to drop unit-extent dims from Linalg " + "operations"; + } + + void runOnOperation() override { + MLIRContext *context = &this->getContext(); + func::FuncOp funcOp = this->getOperation(); + IRRewriter rewriter(context); + SmallVector<linalg::GenericOp> genericOps; + funcOp.walk( + [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); }); + // Best effort: the LogicalResult is deliberately discarded. + for (auto genericOp : genericOps) { + rewriter.setInsertionPoint(genericOp); + (void)dropOutermostUnitDims(rewriter, genericOp); + } + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestLinalgDropUnitDims() { + 
PassRegistration<TestLinalgDropUnitDims>(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -100,6 +100,7 @@ void registerTestInterfaces(); void
registerTestLastModifiedPass(); void registerTestLinalgDecomposeOps(); +void registerTestLinalgDropUnitDims(); void registerTestLinalgElementwiseFusion(); void registerTestLinalgGreedyFusion(); void registerTestLinalgTransforms(); @@ -222,6 +223,7 @@ mlir::test::registerTestInterfaces(); mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgDecomposeOps(); + mlir::test::registerTestLinalgDropUnitDims(); mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgGreedyFusion(); mlir::test::registerTestLinalgTransforms();