diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,12 +19,15 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -225,6 +228,133 @@
   }
 };
 
+/// Pattern to add init operands to ins when all the loops are parallel and
+/// blockArgument corresponding to init is used in the region. This is a fix-up
+/// when unit reduction dimensions are all folded away. In this context, it
+/// becomes an elementwise generic op. E.g., it converts
+///
+///  %0 = tensor.empty() : tensor<1x1xf32>
+///  %1 = linalg.fill
+///    ins(%cst : f32)
+///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+///                                        affine_map<(d0) -> (0, d0)>],
+///                       iterator_types = ["parallel"]}
+///    ins(%arg0 : tensor<1x?x1x1xf32>)
+///    outs(%1 : tensor<1x1xf32>) {
+///  ^bb0(%in: f32, %out: f32):
+///    %3 = arith.addf %in, %out : f32
+///    linalg.yield %3 : f32
+///  } -> tensor<1x1xf32>
+///
+///  into
+///
+///  %0 = tensor.empty() : tensor<1x1xf32>
+///  %1 = linalg.fill
+///    ins(%cst : f32)
+///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+///  %2 = tensor.empty() : tensor<1x1xf32>
+///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+///                                        affine_map<(d0) -> (0, d0)>,
+///                                        affine_map<(d0) -> (0, d0)>],
+///                       iterator_types = ["parallel"]}
+///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
+///    outs(%2 : tensor<1x1xf32>) {
+///  ^bb0(%in: f32, %in_0: f32, %out: f32):
+///    %4 = arith.addf %in, %in_0 : f32
+///    linalg.yield %4 : f32
+///  } -> tensor<1x1xf32>
+struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+      return failure();
+
+    // Only init operands whose block argument is read in the body have to be
+    // fed in as inputs; the others can stay as plain inits.
+    auto outputOperands = genericOp.getDpsInitOperands();
+    SetVector<OpOperand *> candidates;
+    for (OpOperand *op : outputOperands) {
+      if (genericOp.getMatchingBlockArgument(op).use_empty())
+        continue;
+      candidates.insert(op);
+    }
+
+    if (candidates.empty())
+      return failure();
+
+    // Compute the modified indexing maps: the original input maps, then one
+    // map per candidate, then the remaining original maps.
+    int64_t origNumInput = genericOp.getNumDpsInputs();
+    SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
+    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+    SmallVector<AffineMap> newIndexingMaps;
+    newIndexingMaps.append(indexingMaps.begin(),
+                           std::next(indexingMaps.begin(), origNumInput));
+    for (OpOperand *op : candidates) {
+      newInputOperands.push_back(op->get());
+      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
+    }
+    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
+                           indexingMaps.end());
+
+    Location loc = genericOp.getLoc();
+    SmallVector<Value> newOutputOperands = outputOperands;
+    // The init position range is identical for every candidate, so query it
+    // once up front rather than on each iteration.
+    auto [start, end] = genericOp.getDpsInitsPositionRange();
+    for (OpOperand *op : candidates) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointAfterValue(op->get());
+      auto elemType = op->get().getType().cast<ShapedType>().getElementType();
+      auto empty = rewriter.create<tensor::EmptyOp>(
+          loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
+      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
+    }
+
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
+        newIndexingMaps, genericOp.getIteratorTypesArray(),
+        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+
+    // Populate the new body by hand. Block arguments are laid out as
+    // [original inputs..., candidates as inputs..., inits...].
+    Region &region = newOp.getRegion();
+    Block *block = new Block();
+    region.push_back(block);
+    BlockAndValueMapping mapper;
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(block);
+    for (auto bbarg : genericOp.getRegionInputArgs())
+      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+
+    for (OpOperand *op : candidates) {
+      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+    }
+
+    for (OpOperand *op : outputOperands) {
+      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+      // Candidates were already remapped to their new input argument above;
+      // their init argument is still added, but is left unmapped.
+      if (candidates.count(op))
+        block->addArgument(bbarg.getType(), loc);
+      else
+        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+    }
+
+    for (auto &op : genericOp.getBody()->getOperations()) {
+      rewriter.clone(op, mapper);
+    }
+    rewriter.replaceOp(genericOp, newOp.getResults());
+
+    return success();
+  }
+};
+
 struct UnitExtentReplacementInfo {
   Type type;
   AffineMap indexMap;
@@ -536,7 +666,8 @@
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+  patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
+               RankReducedExtractSliceOp,
                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
       context);
@@ -544,6 +675,8 @@
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
 namespace {
@@ -554,10 +687,11 @@
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
-    if (foldOneTripLoopsOnly)
-      patterns.add<FoldUnitDimLoops>(context);
-    else
+    if (foldOneTripLoopsOnly) {
+      patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
+    } else {
       populateFoldUnitExtentDimsPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -384,11 +384,12 @@
 //   CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
 //       CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+//       CHECK:   %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
 //       CHECK:   %[[RESULT:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 //  CHECK-SAME:     iterator_types = ["parallel"]
-//  CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?xf32>)
-//  CHECK-SAME:     outs(%[[FILL]] : tensor<1xf32>)
+//  CHECK-SAME:     ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
+//  CHECK-SAME:     outs(%[[INIT2]] : tensor<1xf32>)
 //       CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //       CHECK:   return %[[RESULT_RESHAPE]]
 