diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2119,6 +2119,54 @@
   }
 };
 
+/// Remove generic/indexed_generic operations (on tensors) that are just copying
+/// the values from inputs to the results. Requirements are
+/// 1) All iterator types are parallel
+/// 2) The body contains just a yield operation with the yielded values being
+///    the arguments corresponding to the operands.
+struct RemoveIdentityLinalgOps : public RewritePattern {
+  RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<GenericOp, IndexedGenericOp>(op))
+      return failure();
+    LinalgOp genericOp = cast<LinalgOp>(op);
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    // Check all indexing maps are identity.
+    if (llvm::any_of(genericOp.getIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+
+    // Check that the body of the linalg operation is just a linalg.yield
+    // operation.
+    Block &body = op->getRegion(0).front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // Get the argument number of the returned values. That is the operand
+    // number to use for replacing uses of this operation.
+    unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
+    SmallVector<Value, 4> returnedArgs;
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+      if (!yieldArg)
+        return failure();
+      unsigned argumentNumber = yieldArg.getArgNumber();
+      if (argumentNumber < numIndexArgs)
+        return failure();
+      returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
+    }
+    rewriter.replaceOp(genericOp, returnedArgs);
+    return success();
+  }
+};
+
 /// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
 /// with the corresponding output tensor argument of the linalg op.
 struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
@@ -2143,7 +2191,8 @@
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                       \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,    \
                                         MLIRContext *context) {               \
-    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
+    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,    \
+                   RemoveIdentityLinalgOps>();                                \
     results.insert<ReplaceDimOfLinalgResult>(context);                        \
   }                                                                           \
                                                                               \
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -249,8 +249,10 @@
   return %1: tensor<0xf32>
 }
 // CHECK-LABEL: @dce_zero_memref
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
 // CHECK-NOT: linalg.copy
-// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: return %[[ARG1]]
 
 // -----
 
@@ -449,3 +451,30 @@
 // CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
 // CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
 // CHECK: return %[[T1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = dim %arg0, %c0 : tensor<?x?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?x?xf32>
+  %2 = dim %arg0, %c2 : tensor<?x?x?xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4, %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
+    linalg.yield %arg3, %arg2 : f32, f32
+  } -> tensor<?x?x?xf32>, tensor<?x?x?xf32>
+  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @remove_no_op
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: return %[[ARG1]], %[[ARG0]]
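
Note (not part of the patch, illustrative only): the `argumentNumber - numIndexArgs` adjustment exists because linalg.indexed_generic passes the loop induction variables as leading block arguments, so a yielded block argument has to be shifted by that count to find the matching operand. A minimal sketch of the indexed case the pattern would fold, assuming the ins/outs tensor form of linalg.indexed_generic (the function name and shapes are invented for the example):

  #id = affine_map<(d0, d1) -> (d0, d1)>
  func @fold_identity_indexed_generic(%arg0 : tensor<?x?xf32>, %init : tensor<?x?xf32>)
      -> tensor<?x?xf32> {
    %0 = linalg.indexed_generic {
        indexing_maps = [#id, #id],
        iterator_types = ["parallel", "parallel"]
      } ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
      // Block arguments: two index args (%i, %j), then the input and output values.
      ^bb0(%i : index, %j : index, %in : f32, %out : f32):
        linalg.yield %in : f32
      } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

Here %in is block argument number 2 and numIndexArgs is 2, so the yielded value maps back to operand 0 (%arg0); running canonicalization (e.g. `mlir-opt -canonicalize`) should then replace all uses of %0 with %arg0 and erase the dead indexed_generic.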