diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,12 +19,15 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -225,6 +228,128 @@
   }
 };
 
+/// Pattern to replace an init operand with a tensor.empty op when both of
+/// the conditions below hold:
+///   1) The corresponding block argument has no uses, i.e., the op never
+///      reads the init value.
+///   2) The init operand is defined by a LinalgOp, so dropping the use does
+///      not lose any data.
+struct ReplaceWithEmptyTensorIfUnused : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    int numInput = genericOp.getNumDpsInputs();
+    int numInit = genericOp.getNumDpsInits();
+    SmallVector<int> candidates;
+    for (int i = 0; i < numInit; ++i)
+      if (genericOp.getBody()->getArgument(numInput + i).use_empty())
+        candidates.push_back(i);
+
+    if (candidates.empty())
+      return failure();
+
+    Location loc = genericOp.getLoc();
+    bool changed = false;
+    rewriter.startRootUpdate(genericOp);
+
+    for (int idx : candidates) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      Value operand = genericOp.getDpsInitOperand(idx)->get();
+      if (!operand.getDefiningOp<linalg::LinalgOp>())
+        continue;
+      rewriter.setInsertionPointAfterValue(operand);
+      auto elemType =
+          operand.getType().cast<RankedTensorType>().getElementType();
+      auto empty = rewriter.create<tensor::EmptyOp>(
+          loc, tensor::createDimValues(rewriter, loc, operand), elemType);
+      genericOp.setDpsInitOperand(idx, empty.getResult());
+      changed = true;
+    }
+
+    // If nothing was replaced, cancel the root update instead of notifying
+    // the driver of an in-place change and then returning failure.
+    if (!changed) {
+      rewriter.cancelRootUpdate(genericOp);
+      return failure();
+    }
+    rewriter.finalizeRootUpdate(genericOp);
+    return success();
+  }
+};
+
+/// Pattern to add init operands to ins when all the loops are parallel and
+/// the block arguments tied to the inits have uses. Uses of an init block
+/// argument are redirected to a newly appended input block argument, which
+/// leaves the init block argument dead so ReplaceWithEmptyTensorIfUnused can
+/// then replace the init with a tensor.empty op.
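+///
+/// As an illustrative, simplified example (mirroring the updated test below;
+/// the names and types are made up), this rewrites
+///
+///   %res = linalg.generic {indexing_maps = [#map, #map],
+///                          iterator_types = ["parallel"]}
+///       ins(%in : tensor<?xf32>) outs(%fill : tensor<1xf32>) {...}
+///
+/// into
+///
+///   %res = linalg.generic {indexing_maps = [#map, #map, #map],
+///                          iterator_types = ["parallel"]}
+///       ins(%in, %fill : tensor<?xf32>, tensor<1xf32>)
+///       outs(%fill : tensor<1xf32>) {...}
+///
+/// where the region now reads the new input argument instead of the init
+/// argument.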
+struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+      return failure();
+
+    int numInit = genericOp.getNumDpsInits();
+    int64_t origNumInput = genericOp.getNumDpsInputs();
+    SmallVector<Value> outputOperands = genericOp.getDpsInitOperands();
+    SetVector<int64_t> candidates;
+    for (int i = 0; i < numInit; ++i) {
+      if (genericOp.getBody()->getArgument(origNumInput + i).use_empty())
+        continue;
+      candidates.insert(i);
+    }
+
+    if (candidates.empty())
+      return failure();
+
+    // Compute the modified indexing maps.
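+    // The new operand order is: original inputs, then the candidate inits
+    // (now also read as inputs), then all original inits. E.g., with maps
+    // [in0, out0, out1] and candidates = {1}, the new maps become
+    // [in0, out1, out0, out1].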
+    SmallVector<Value> inputOperands = genericOp.getDpsInputOperands();
+    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+    SmallVector<AffineMap> newIndexingMaps;
+    newIndexingMaps.append(indexingMaps.begin(),
+                           std::next(indexingMaps.begin(), origNumInput));
+    for (int64_t i : candidates) {
+      inputOperands.push_back(outputOperands[i]);
+      newIndexingMaps.push_back(indexingMaps[i + origNumInput]);
+    }
+    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
+                           indexingMaps.end());
+
+    Location loc = genericOp.getLoc();
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp.getResultTypes(), inputOperands, outputOperands,
+        newIndexingMaps, genericOp.getIteratorTypesArray(),
+        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+
+    // Rebuild the region: original inputs map 1:1, candidate inits map to
+    // the newly appended input arguments, and fresh (intentionally unused)
+    // arguments are added for the candidate inits themselves.
+    Region &region = newOp.getRegion();
+    Block *block = new Block();
+    region.push_back(block);
+    BlockAndValueMapping mapper;
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(block);
+    for (int i = 0; i < origNumInput; ++i) {
+      mapper.map(genericOp.getBody()->getArgument(i),
+                 block->addArgument(
+                     getElementTypeOrSelf(inputOperands[i].getType()), loc));
+    }
+    for (auto en : llvm::enumerate(candidates)) {
+      // Note: the argument type comes from the en.value()-th init, which is
+      // the operand appended as the en.index()-th extra input above.
+      mapper.map(
+          genericOp.getBody()->getArgument(en.value() + origNumInput),
+          block->addArgument(
+              getElementTypeOrSelf(outputOperands[en.value()].getType()),
+              loc));
+    }
+    for (int i = 0, e = outputOperands.size(); i < e; ++i) {
+      if (candidates.count(i)) {
+        block->addArgument(getElementTypeOrSelf(outputOperands[i].getType()),
+                           loc);
+      } else {
+        mapper.map(genericOp.getBody()->getArgument(i + origNumInput),
+                   block->addArgument(
+                       getElementTypeOrSelf(outputOperands[i].getType()),
+                       loc));
+      }
+    }
+
+    for (auto &op : genericOp.getBody()->getOperations())
+      rewriter.clone(op, mapper);
+    rewriter.replaceOp(genericOp, newOp.getResults());
+
+    return success();
+  }
+};
+
 struct UnitExtentReplacementInfo {
   Type type;
   AffineMap indexMap;
@@ -536,7 +661,9 @@
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
-               RankReducedExtractSliceOp, RankReducedInsertSliceOp>(
-      context);
+  patterns.add<FoldUnitDimLoops, AddInitOperandsToInput,
+               ReplaceWithEmptyTensorIfUnused, ReplaceUnitExtents,
+               RankReducedExtractSliceOp, RankReducedInsertSliceOp>(
+      context);
@@ -544,6 +671,8 @@
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
 namespace {
@@ -554,10 +683,12 @@
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
-    if (foldOneTripLoopsOnly)
-      patterns.add<FoldUnitDimLoops>(context);
-    else
+    if (foldOneTripLoopsOnly) {
+      patterns.add<AddInitOperandsToInput, FoldUnitDimLoops>(context);
+    } else {
       populateFoldUnitExtentDimsPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -384,11 +384,12 @@
 // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
+// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]