diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,12 +19,14 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -102,11 +104,11 @@
 /// Given dims of the iteration space of a structured op that are known to be
 /// single trip count (`unitDims`), return the indexing maps to use in the
 /// canonicalized op with these dims removed, given the original `indexingMaps`.
-static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
-                                 ArrayRef<AffineMap> indexingMaps,
-                                 MLIRContext *context) {
+static Optional<SmallVector<AffineMap>>
+replaceUnitDims(DenseSet<unsigned> &unitDims, ArrayRef<AffineMap> indexingMaps,
+                MLIRContext *context) {
   if (indexingMaps.empty())
-    return nullptr;
+    return {};
   unsigned numIterationDims = indexingMaps.front().getNumDims();
   unsigned numSymbols = indexingMaps.front().getNumSymbols();
@@ -127,12 +129,12 @@
   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
 
-  SmallVector<AffineMap, 4> newIndexingMaps;
+  SmallVector<AffineMap> newIndexingMaps;
   newIndexingMaps.reserve(indexingMaps.size());
   for (AffineMap operandMap : indexingMaps) {
     // Expected indexing maps to have no symbols.
     if (operandMap.getNumSymbols())
-      return nullptr;
+      return llvm::None;
     newIndexingMaps.push_back(simplifyAffineMap(
         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                          numIterationDims - unitDims.size(),
@@ -142,20 +144,25 @@
   // Check that the new index maps are invertible. If not, something went
   // wrong, so abort.
   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
-    return nullptr;
+    return llvm::None;
+
+  return newIndexingMaps;
+}
+
+static ArrayAttr indexingMapsToArrayAttr(ArrayRef<AffineMap> indexingMaps,
+                                         MLIRContext *context) {
   return ArrayAttr::get(context,
                         llvm::to_vector<4>(llvm::map_range(
-                            newIndexingMaps, [](AffineMap map) -> Attribute {
+                            indexingMaps, [](AffineMap map) -> Attribute {
                               return AffineMapAttr::get(map);
                             })));
 }
 
 /// Update the index accesses of linalg operations having index semantics.
-static void replaceUnitDimIndexOps(GenericOp genericOp,
+static void replaceUnitDimIndexOps(Block *block,
                                    const DenseSet<unsigned> &unitDims,
                                    PatternRewriter &rewriter) {
-  for (IndexOp indexOp :
-       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
+  for (IndexOp indexOp : llvm::make_early_inc_range(block->getOps<IndexOp>())) {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(indexOp);
     if (unitDims.count(indexOp.getDim()) != 0) {
@@ -172,8 +179,8 @@
 }
 
 namespace {
-/// Pattern to fold unit-trip count loops in GenericOps.
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+/// Pattern to fold parallel unit-trip count loops in GenericOps.
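+/// Only dimensions with a parallel iterator type are folded here; unit-trip
+/// reduction dimensions also require rewiring the output operands and are
+/// handled separately by FoldReductionUnitDimLoops below.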
+struct FoldParallelUnitDimLoops : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -190,12 +197,13 @@
     SmallVector<int64_t> dims = genericOp.getStaticShape();
 
     DenseSet<unsigned> unitDims;
-    SmallVector<unsigned, 4> unitDimsReductionLoops;
-    ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
+    auto iteratorTypes = genericOp.getIteratorTypesArray();
     for (const auto &expr : enumerate(invertedMap.getResults())) {
-      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1)
+      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+        if (dims[dimExpr.getPosition()] == 1 &&
+            isParallelIterator(iteratorTypes[expr.index()]))
           unitDims.insert(expr.index());
+      }
     }
 
     if (unitDims.empty())
@@ -203,28 +211,221 @@
     // Compute the modified indexing maps.
     MLIRContext *context = rewriter.getContext();
-    ArrayAttr newIndexingMapAttr =
+    Optional<SmallVector<AffineMap>> newIndexingMaps =
         replaceUnitDims(unitDims, indexingMaps, context);
-    if (!newIndexingMapAttr)
+    if (!newIndexingMaps)
       return genericOp.emitError("unable to compute modified indexing_maps");
+    ArrayAttr newIndexingMapAttr =
+        indexingMapsToArrayAttr(newIndexingMaps.value(), context);
 
     // Compute the iterator types of the modified op by dropping the one-trip
     // count loops.
-    SmallVector<Attribute, 4> newIteratorTypes;
+    SmallVector<Attribute> newIteratorTypes;
     for (const auto &attr : llvm::enumerate(iteratorTypes)) {
       if (!unitDims.count(attr.index()))
-        newIteratorTypes.push_back(attr.value());
+        newIteratorTypes.push_back(
+            IteratorTypeAttr::get(context, attr.value()));
     }
 
     rewriter.startRootUpdate(genericOp);
     genericOp.setIndexingMapsAttr(newIndexingMapAttr);
     genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
-    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
+    replaceUnitDimIndexOps(genericOp.getBody(), unitDims, rewriter);
     rewriter.finalizeRootUpdate(genericOp);
     return success();
   }
 };
 
+static Value getDimValue(OpBuilder &builder, Location loc, Value v,
+                         int64_t dim) {
+  ShapedType type = v.getType().cast<ShapedType>();
+  if (!type.isDynamicDim(dim)) {
+    return builder.create<arith::ConstantIndexOp>(loc, type.getDimSize(dim));
+  }
+  return TypeSwitch<Type, Value>(v.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+        return builder.create<tensor::DimOp>(loc, v, dim);
+      })
+      .Case<MemRefType>([&](MemRefType t) -> Value {
+        return builder.create<memref::DimOp>(loc, v, dim);
+      });
+}
+
+static OpFoldResult getDim(OpBuilder &builder, Location loc, Value v,
+                           int64_t dim) {
+  auto t = v.getType().cast<ShapedType>();
+  if (t.isDynamicDim(dim)) {
+    return getDimValue(builder, loc, v, dim);
+  }
+  return builder.getI64IntegerAttr(t.getDimSize(dim));
+}
+
+static SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc,
+                                         Value shapedTypeValue) {
+  return llvm::to_vector(llvm::map_range(
+      llvm::seq<int64_t>(
+          0, shapedTypeValue.getType().cast<ShapedType>().getRank()),
+      [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); }));
+}
+
+/// Pattern to replace output operands with tensor.empty() ops if the
+/// following conditions are met:
+///   1) The corresponding output block argument is unused.
+///   2) The output operand is produced by a LinalgOp.
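+/// This cleans up the outputs that FoldReductionUnitDimLoops (below) also
+/// turned into inputs on ops with tensor semantics: since the corresponding
+/// init value is never read, it can be replaced by a fresh tensor.empty of
+/// the same shape.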
+struct ReplaceWithEmptyTensorIfUnused : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    int numInput = genericOp.getNumDpsInputs();
+    int numInit = genericOp.getNumDpsInits();
+    SmallVector<int64_t> candidates;
+    for (int i = 0; i < numInit; ++i)
+      if (genericOp.getBody()->getArgument(numInput + i).use_empty())
+        candidates.push_back(i);
+
+    if (candidates.empty())
+      return failure();
+
+    Location loc = genericOp.getLoc();
+    bool changed = false;
+    rewriter.startRootUpdate(genericOp);
+
+    for (auto idx : candidates) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      Value operand = genericOp.getDpsInitOperand(idx)->get();
+      if (!operand.getDefiningOp<LinalgOp>())
+        continue;
+      rewriter.setInsertionPointAfterValue(operand);
+      auto elemType = operand.getType().cast<ShapedType>().getElementType();
+      auto empty = rewriter.create<tensor::EmptyOp>(
+          loc, getDims(rewriter, loc, operand), elemType);
+      genericOp.setDpsInitOperand(idx, empty.getResult());
+      changed = true;
+    }
+
+    rewriter.finalizeRootUpdate(genericOp);
+    return success(changed);
+  }
+};
+
+/// Pattern to fold reduction unit-trip count loops in GenericOps. Output
+/// operands whose indexing maps change are additionally appended to the input
+/// operands; the outputs of the generic op stay the same. For ops with tensor
+/// semantics, these outputs are later replaced with tensor.empty by the
+/// ReplaceWithEmptyTensorIfUnused pattern. For memref semantics, the buffer
+/// must be kept so the result is still yielded to the correct buffer.
+struct FoldReductionUnitDimLoops : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+    if (indexingMaps.empty())
+      return failure();
+
+    // Check if any of the iteration dimensions are unit-trip count. They will
+    // end up being unit-trip count if they are used to index into a unit-dim
+    // tensor/memref.
+    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+    if (!invertedMap)
+      return failure();
+    SmallVector<int64_t> dims = genericOp.getStaticShape();
+
+    DenseSet<unsigned> unitDims;
+    auto iteratorTypes = genericOp.getIteratorTypesArray();
+    for (const auto &expr : enumerate(invertedMap.getResults())) {
+      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+        if (dims[dimExpr.getPosition()] == 1 &&
+            isReductionIterator(iteratorTypes[expr.index()]))
+          unitDims.insert(expr.index());
+      }
+    }
+
+    if (unitDims.empty())
+      return failure();
+
+    // Compute the modified indexing maps.
+    MLIRContext *context = rewriter.getContext();
+    Optional<SmallVector<AffineMap>> updatedIndexingMaps =
+        replaceUnitDims(unitDims, indexingMaps, context);
+    if (!updatedIndexingMaps)
+      return genericOp.emitError("unable to compute modified indexing_maps");
+
+    // Compute the iterator types of the modified op by dropping the one-trip
+    // count loops.
+    SmallVector<utils::IteratorType> newIteratorTypes;
+    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
+      if (!unitDims.count(attr.index()))
+        newIteratorTypes.push_back(attr.value());
+    }
+
+    int64_t origNumInput = genericOp.getNumDpsInputs();
+    SmallVector<Value> inputOperands = genericOp.getDpsInputOperands();
+    SmallVector<Value> outputOperands = genericOp.getDpsInitOperands();
+    SmallVector<AffineMap> newIndexingMaps;
+    newIndexingMaps.append(
+        updatedIndexingMaps.value().begin(),
+        std::next(updatedIndexingMaps.value().begin(), origNumInput));
+    SmallVector<int64_t> changedOutput;
+    for (int i = 0, e = genericOp.getNumDpsInits(); i < e; ++i) {
+      AffineMap map = updatedIndexingMaps.value()[i + origNumInput];
+      if (map != indexingMaps[i + origNumInput]) {
+        changedOutput.push_back(i);
+        inputOperands.push_back(outputOperands[i]);
+        newIndexingMaps.push_back(map);
+      }
+    }
+    newIndexingMaps.append(
+        std::next(updatedIndexingMaps.value().begin(), origNumInput),
+        updatedIndexingMaps.value().end());
+
+    Location loc = genericOp.getLoc();
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp.getResultTypes(), inputOperands, outputOperands,
+        newIndexingMaps, newIteratorTypes, /*bodyBuild=*/nullptr,
+        linalg::getPrunedAttributeList(genericOp));
+
+    Region &region = newOp.getRegion();
+
+    Block *block = new Block();
+    region.push_back(block);
+    BlockAndValueMapping mapper;
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(block);
+    for (int i = 0; i < origNumInput; ++i) {
+      mapper.map(genericOp.getBody()->getArgument(i),
+                 block->addArgument(
+                     getElementTypeOrSelf(inputOperands[i].getType()), loc));
+    }
+    for (auto en : llvm::enumerate(changedOutput)) {
+      mapper.map(genericOp.getBody()->getArgument(en.value() + origNumInput),
+                 block->addArgument(
+                     getElementTypeOrSelf(
+                         inputOperands[en.index() + origNumInput].getType()),
+                     loc));
+    }
+    for (int i = 0, e = outputOperands.size(); i < e; ++i) {
+      AffineMap map = updatedIndexingMaps.value()[i + origNumInput];
+      if (map != indexingMaps[i + origNumInput]) {
+        block->addArgument(getElementTypeOrSelf(outputOperands[i].getType()),
+                           loc);
+      } else {
+        mapper.map(genericOp.getBody()->getArgument(i + origNumInput),
+                   block->addArgument(
+                       getElementTypeOrSelf(outputOperands[i].getType()), loc));
+      }
+    }
+
+    for (auto &op : genericOp.getBody()->getOperations()) {
+      rewriter.clone(op, mapper);
+    }
+    replaceUnitDimIndexOps(block, unitDims, rewriter);
+    rewriter.replaceOp(genericOp, newOp.getResults());
+
+    return success();
+  }
+};
+
 struct UnitExtentReplacementInfo {
   Type type;
   AffineMap indexMap;
@@ -536,7 +737,9 @@
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+  patterns.add<FoldParallelUnitDimLoops, FoldReductionUnitDimLoops,
+               ReplaceWithEmptyTensorIfUnused, ReplaceUnitExtents,
+               RankReducedExtractSliceOp,
                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
       context);
@@ -544,6 +747,8 @@
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
 namespace {
@@ -554,10 +759,12 @@
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
-    if (foldOneTripLoopsOnly)
-      patterns.add<FoldUnitDimLoops>(context);
-    else
+    if (foldOneTripLoopsOnly) {
+      patterns.add<FoldParallelUnitDimLoops, FoldReductionUnitDimLoops,
+                   ReplaceWithEmptyTensorIfUnused>(context);
+    } else {
       populateFoldUnitExtentDimsPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -344,18 +344,21 @@
   } -> tensor<1x?xf32>
   return %3 : tensor<1x?xf32>
 }
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // CHECK: func @unit_dim_for_reduction
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<1x?x1x?xf32>
 // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
-// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+// CHECK: %[[INIT2:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK-SAME: outs(%[[INIT2]] : tensor<?xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]
@@ -384,11 +387,12 @@
 // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
+// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]
@@ -419,11 +423,12 @@
 // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+// CHECK: %[[INIT2:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK-SAME: outs(%[[INIT2]] : tensor<?xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]