diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -19,15 +19,12 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -228,125 +225,6 @@ } }; -/// Pattern to add init operands to ins when all the loops are parallel and -/// blockArgument corresponding to init is used in the region. This is a fix-up -/// when unit reduction dimensions are all folded away. In this context, it -/// becomes a elementwise generic op. E.g., it converts -/// -/// %0 = tensor.empty() : tensor<1x1xf32> -/// %1 = linalg.fill -/// ins(%cst : f32) -/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> -/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, -/// affine_map<(d0) -> (0, d0)>], -/// iterator_types = ["parallel"]} -/// ins(%arg0 : tensor<1x?x1x1xf32>) -/// outs(%1 : tensor<1x1xf32>) { -/// ^bb0(%in: f32, %out: f32): -/// %3 = arith.addf %in, %out : f32 -/// linalg.yield %3 : f32 -/// } -> tensor<1x1xf32> -/// -/// into -/// -/// %0 = tensor.empty() : tensor<1x1xf32> -/// %1 = linalg.fill -/// ins(%cst : f32) -/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> -/// %2 = tensor.empty() : tensor<1x1xf32> -/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, -/// affine_map<(d0) -> (0, d0)>, -/// affine_map<(d0) -> (0, d0)>], -/// iterator_types = ["parallel"]} -/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>) -/// outs(%2 : tensor<1x1xf32>) { -/// ^bb0(%in: f32, %in_0: f32, %out: f32): -/// %4 = arith.addf %in, %in_0 : f32 -/// linalg.yield %4 : f32 -/// } -> tensor<1x1xf32> -struct AddInitOperandsToInput : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!genericOp.hasTensorSemantics()) - return failure(); - if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) - return failure(); - - auto outputOperands = genericOp.getDpsInitOperands(); - SetVector candidates; - for (OpOperand *op : outputOperands) { - if (genericOp.getMatchingBlockArgument(op).use_empty()) - continue; - candidates.insert(op); - } - - if (candidates.empty()) - return failure(); - - // Compute the modified indexing maps. - int64_t origNumInput = genericOp.getNumDpsInputs(); - SmallVector newInputOperands = genericOp.getDpsInputOperands(); - SmallVector indexingMaps = genericOp.getIndexingMapsArray(); - SmallVector newIndexingMaps; - newIndexingMaps.append(indexingMaps.begin(), - std::next(indexingMaps.begin(), origNumInput)); - for (OpOperand *op : candidates) { - newInputOperands.push_back(op->get()); - newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op)); - } - newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput), - indexingMaps.end()); - - Location loc = genericOp.getLoc(); - SmallVector newOutputOperands = outputOperands; - for (OpOperand *op : candidates) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfterValue(op->get()); - auto elemType = op->get().getType().cast().getElementType(); - auto empty = rewriter.create( - loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); - - auto [start, end] = genericOp.getDpsInitsPositionRange(); - newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); - } - - auto newOp = rewriter.create( - loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands, - newIndexingMaps, genericOp.getIteratorTypesArray(), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); - - Region ®ion = newOp.getRegion(); - Block *block = new Block(); - region.push_back(block); - BlockAndValueMapping mapper; - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(block); - for (auto bbarg : genericOp.getRegionInputArgs()) - mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); - - for (OpOperand *op : candidates) { - BlockArgument bbarg = genericOp.getMatchingBlockArgument(op); - mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); - } - - for (OpOperand *op : outputOperands) { - BlockArgument bbarg = genericOp.getMatchingBlockArgument(op); - if (candidates.count(op)) - block->addArgument(bbarg.getType(), loc); - else - mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); - } - - for (auto &op : genericOp.getBody()->getOperations()) { - rewriter.clone(op, mapper); - } - rewriter.replaceOp(genericOp, newOp.getResults()); - - return success(); - } -}; - struct UnitExtentReplacementInfo { Type type; AffineMap indexMap; @@ -658,8 +536,7 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); - patterns.add, RankReducedInsertSliceOp>( context); @@ -667,8 +544,6 @@ tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); - memref::populateResolveShapedTypeResultDimsPatterns(patterns); } namespace { @@ -680,7 +555,7 @@ MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); if (foldOneTripLoopsOnly) - patterns.add(context); + patterns.add(context); else populateFoldUnitExtentDimsPatterns(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -384,12 +384,11 @@ // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32> // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] -// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32> // CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor, tensor<1xf32>) -// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>) +// CHECK-SAME: ins(%[[RESHAPE]] : tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>) // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8301,7 +8301,6 @@ ":LinalgUtils", ":MathDialect", ":MemRefDialect", - ":MemRefTransforms", ":Pass", ":SCFDialect", ":SCFTransforms",