Canonicalize `dim` of a linalg op result.

Adds the ReplaceDimOfLinalgResult rewrite pattern: when the operand of a `dim`
op is a result of a linalg op, the `dim` is replaced by a `dim` on the linalg
op's output tensor at the same result position. The pattern is registered in
CANONICALIZERS_AND_FOLDERS, and a test is added to canonicalize.mlir.

NOTE(review): the template/type arguments in this patch were reconstructed
after being stripped from the text. The pattern names in the three removed
`results.insert<...>()` lines and in the merged replacement line
(EraseDeadLinalgOp, FoldTensorCastOp, DeduplicateInputs), and the
`ArrayRef<Attribute>` context line, are presumed from the upstream file and
are not derivable from the patch itself -- verify names, order and whitespace
against LinalgOps.cpp before applying.

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1958,14 +1958,33 @@
     return success();
   }
 };
+
+/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
+/// with the corresponding output tensor argument of the linalg op.
+struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimOpArg = dimOp.memrefOrTensor();
+    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    auto results = linalgOp.getOperation()->getResults();
+    int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
+    auto outputTensors = linalgOp.getOutputTensors();
+    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                         MLIRContext *context) {                \
-    results.insert<EraseDeadLinalgOp>();                                       \
-    results.insert<FoldTensorCastOp>();                                        \
-    results.insert<DeduplicateInputs>();                                       \
+    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>();  \
+    results.insert<ReplaceDimOfLinalgResult>(context);                         \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -389,3 +389,31 @@
 // CHECK: func @init_tensor_dynamic_dim
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK: return %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg_0 : tensor<?xf32>)
+    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> tensor<?xf32>, tensor<?xf32>
+
+  %c0 = constant 0 : index
+  %num_elem_0 = dim %0, %c0 : tensor<?xf32>
+  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
+
+  %num_elem_1 = dim %1, %c0 : tensor<?xf32>
+  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
+  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+}
+// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME:     [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
+// CHECK:     dim [[ARG_0]]
+// CHECK:     dim [[ARG_1]]