diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -799,11 +799,36 @@
     return success();
   }
 };
+
+/// Folds a `tensor.dim` with a constant index whose source is produced by a
+/// `linalg.init_tensor`: a dynamic dimension is replaced by the matching size
+/// operand of the init_tensor, a static dimension by a constant holding that
+/// dimension's size.
+struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
+    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
+    if (!initTensorOp || !maybeConstantIndex)
+      return failure();
+    if (initTensorOp.isDynamicSize(*maybeConstantIndex)) {
+      rewriter.replaceOp(dimOp,
+                         initTensorOp.getDynamicSize(*maybeConstantIndex));
+      return success();
+    }
+    // Static dimension: materialize its size, not the queried index.
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+        dimOp, initTensorOp.getStaticSize(*maybeConstantIndex));
+    return success();
+  }
+};
 } // namespace
 
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<FoldInitTensorWithExtractSliceOp,
+  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
               FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
               FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
               ReplaceStaticShapeDims>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -873,3 +873,26 @@
   return %0 : tensor<2x3x4xf32>
 }
 
+// -----
+
+func private @some_use(%i : index, %j : index)
+
+// CHECK-LABEL: func @init_canonicalize
+// CHECK-SAME: %[[I:.*]]: index
+func @init_canonicalize(%i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK-NOT: init_tensor
+  %0 = linalg.init_tensor [%i, 42] : tensor<?x42xf32>
+
+  // CHECK-NOT: tensor.dim
+  %1 = tensor.dim %0, %c0: tensor<?x42xf32>
+  %2 = tensor.dim %0, %c1: tensor<?x42xf32>
+
+  // CHECK: %[[c42:.*]] = constant 42 : index
+  // CHECK: call @some_use(%[[I]], %[[c42]])
+  call @some_use(%1, %2) : (index, index) -> ()
+
+  return
+}