diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -146,13 +146,13 @@ } /// Update the index accesses of linalg operations having index semantics. -template -static void replaceUnitDimIndexOps(GenericOpTy op, +static void replaceUnitDimIndexOps(GenericOp genericOp, const DenseSet &unitDims, PatternRewriter &rewriter) { - assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 && + assert(genericOp->getNumRegions() == 1 && + genericOp->getRegion(0).getBlocks().size() == 1 && "expected generic operation to have one block."); - Block &block = op->getRegion(0).front(); + Block &block = genericOp->getRegion(0).front(); for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps())) { OpBuilder::InsertionGuard guard(rewriter); @@ -170,39 +170,13 @@ } } -/// Modify the region of indexed generic op to drop arguments corresponding to -/// loops that are unit trip count. -template -static LogicalResult -replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet &unitDims, - PatternRewriter &rewriterp) { - return success(); -} - -template <> -LogicalResult replaceBlockArgForUnitDimLoops( - IndexedGenericOp op, const DenseSet &unitDims, - PatternRewriter &rewriter) { - OpBuilder::InsertionGuard guard(rewriter); - Block *entryBlock = &op->getRegion(0).front(); - rewriter.setInsertionPointToStart(entryBlock); - Value zero = rewriter.create(op.getLoc(), 0); - for (unsigned unitDimLoop : unitDims) { - entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero); - } - SmallVector unitDimsToErase(unitDims.begin(), unitDims.end()); - entryBlock->eraseArguments(unitDimsToErase); - return success(); -} - namespace { /// Pattern to fold unit-trip count loops in GenericOps. -template -struct FoldUnitDimLoops : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(GenericOpTy op, +struct FoldUnitDimLoops : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - SmallVector indexingMaps = op.getIndexingMaps(); + SmallVector indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.empty()) return failure(); @@ -213,7 +187,7 @@ if (!invertedMap) return failure(); SmallVector dims; - for (ShapedType shapedType : op.getShapedOperandTypes()) + for (ShapedType shapedType : genericOp.getShapedOperandTypes()) dims.append(shapedType.getShape().begin(), shapedType.getShape().end()); // Find all the reduction iterators. Those need some special consideration @@ -221,7 +195,7 @@ auto getLoopDimsOfType = [&](StringRef iteratorTypeName) -> SmallVector { SmallVector dimExprs; - getDimsOfType(op, iteratorTypeName, dimExprs); + getDimsOfType(genericOp, iteratorTypeName, dimExprs); return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) { return expr.cast().getPosition(); })); @@ -230,7 +204,7 @@ DenseSet unitDims; SmallVector unitDimsReductionLoops; - ArrayAttr iteratorTypes = op.iterator_types(); + ArrayAttr iteratorTypes = genericOp.iterator_types(); for (auto expr : enumerate(invertedMap.getResults())) { if (AffineDimExpr dimExpr = expr.value().dyn_cast()) if (dims[dimExpr.getPosition()] == 1) { @@ -260,7 +234,7 @@ ArrayAttr newIndexingMapAttr = replaceUnitDims(unitDims, indexingMaps, context); if (!newIndexingMapAttr) - return op.emitError("unable to compute modified indexing_maps"); + return genericOp.emitError("unable to compute modified indexing_maps"); // Compute the iterator types of the modified op by dropping the one-trip // count loops. @@ -270,12 +244,11 @@ newIteratorTypes.push_back(attr.value()); } - rewriter.startRootUpdate(op); - op.indexing_mapsAttr(newIndexingMapAttr); - op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes)); - (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter); - replaceUnitDimIndexOps(op, unitDims, rewriter); - rewriter.finalizeRootUpdate(op); + rewriter.startRootUpdate(genericOp); + genericOp.indexing_mapsAttr(newIndexingMapAttr); + genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes)); + replaceUnitDimIndexOps(genericOp, unitDims, rewriter); + rewriter.finalizeRootUpdate(genericOp); return success(); } }; @@ -351,23 +324,22 @@ } /// Pattern to replace tensors operands/results that are unit extents. -template -struct ReplaceUnitExtentTensors : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(GenericOpTy op, +struct ReplaceUnitExtentTensors : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - if (!op.hasTensorSemantics()) + if (!genericOp.hasTensorSemantics()) return failure(); MLIRContext *context = rewriter.getContext(); - Location loc = op.getLoc(); + Location loc = genericOp.getLoc(); SmallVector newIndexingMaps; SmallVector reassociationMaps; SmallVector newInputOutputTypes; bool doCanonicalization = false; - for (auto it : - llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) { + for (auto it : llvm::zip(genericOp.getIndexingMaps(), + genericOp.getShapedOperandTypes())) { auto replacementInfo = replaceUnitExtents( std::get<0>(it), std::get<1>(it).template cast(), context); @@ -402,20 +374,20 @@ return res; }; - SmallVector newInputs = insertReshapes(op.inputs()); - SmallVector newOutputs = insertReshapes(op.outputs()); + SmallVector newInputs = insertReshapes(genericOp.inputs()); + SmallVector newOutputs = insertReshapes(genericOp.outputs()); // If any result type changes, insert a reshape to convert from the original // type to the new type. SmallVector resultTypes; - resultTypes.reserve(op.getNumResults()); - for (unsigned i : llvm::seq(0, op.getNumResults())) - resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]); - GenericOpTy replacementOp = rewriter.create( + resultTypes.reserve(genericOp.getNumResults()); + for (unsigned i : llvm::seq(0, genericOp.getNumResults())) + resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]); + GenericOp replacementOp = rewriter.create( loc, resultTypes, newInputs, newOutputs, newIndexingMaps, llvm::to_vector<4>( - op.iterator_types().template getAsValueRange())); - rewriter.inlineRegionBefore(op.region(), replacementOp.region(), + genericOp.iterator_types().template getAsValueRange())); + rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(), replacementOp.region().begin()); // If any result tensor has a modified shape, then add reshape to recover @@ -423,7 +395,7 @@ SmallVector resultReplacements; for (auto result : llvm::enumerate(replacementOp.getResults())) { unsigned index = result.index() + replacementOp.getNumInputs(); - RankedTensorType origResultType = op.getResult(result.index()) + RankedTensorType origResultType = genericOp.getResult(result.index()) .getType() .template cast(); if (origResultType != result.value().getType()) @@ -433,7 +405,7 @@ else resultReplacements.push_back(result.value()); } - rewriter.replaceOp(op, resultReplacements); + rewriter.replaceOp(genericOp, resultReplacements); return success(); } }; @@ -528,9 +500,7 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); - patterns.add, FoldUnitDimLoops, - ReplaceUnitExtentTensors, - ReplaceUnitExtentTensors, + patterns.add( context); TensorReshapeOp::getCanonicalizationPatterns(patterns, context); @@ -545,9 +515,7 @@ MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); if (foldOneTripLoopsOnly) - patterns - .add, FoldUnitDimLoops>( - context); + patterns.add(context); else populateFoldUnitExtentDimsPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -42,48 +42,6 @@ library_call = "some_external_func" } -func @drop_one_trip_loops_indexed_generic - (%arg0 : tensor, %shape: tensor) -> tensor -{ - %0 = linalg.indexed_generic #trait - ins(%arg0 : tensor) - outs(%shape: tensor) { - ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, - %arg5 : index, %arg6 : i32, %arg7 : i32) : - %1 = addi %arg1, %arg2 : index - %2 = addi %1, %arg3 : index - %3 = addi %2, %arg4 : index - %4 = addi %3, %arg5 : index - %5 = index_cast %4 : index to i32 - %6 = addi %5, %arg6 : i32 - linalg.yield %6 : i32 - } -> tensor - return %0 : tensor -} -// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic -// CHECK: linalg.indexed_generic -// CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32) -// CHECK: %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]] -// CHECK: %[[T4:.+]] = addi %[[T3]], %[[ARG3]] -// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32 -// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32 -// CHECK: linalg.yield %[[T6]] : i32 - -// ----- - -#accesses = [ - affine_map<(i, j, k, l, m) -> (i, k, m)>, - affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> -] - -#trait = { - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], - indexing_maps = #accesses, - library_call = "some_external_func" -} - func @drop_one_trip_loops_indexed (%arg0 : tensor, %shape: tensor) -> tensor { @@ -158,35 +116,6 @@ library_call = "some_external_func" } -func @drop_all_loops_indexed_generic - (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{ - %0 = linalg.indexed_generic #trait - ins(%arg0 : tensor<1x1xi32>) - outs(%arg0 : tensor<1x1xi32>) { - ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) : - %1 = addi %arg1, %arg2 : index - %2 = index_cast %1 : index to i32 - %3 = addi %2, %arg3 : i32 - linalg.yield %3 : i32 - } -> tensor<1x1xi32> - return %0 : tensor<1x1xi32> -} - -// CHECK-LABEL: func @drop_all_loops_indexed_generic -// CHECK: linalg.indexed_generic -// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -// CHECK: linalg.yield %[[ARG1]] : i32 - -// ----- - -#map0 = affine_map<(i, j) -> (i, j)> -#access = [#map0, #map0] -#trait = { - iterator_types = ["parallel", "parallel"], - indexing_maps = #access, - library_call = "some_external_func" -} - func @drop_all_loops_indexed (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{ %0 = linalg.generic #trait