diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -799,11 +799,42 @@
     return success();
   }
 };
+
+/// Folds `tensor.dim` of a `linalg.init_tensor` result when the dimension
+/// index is a constant: a dynamic dimension is replaced by the matching size
+/// operand of the `init_tensor`, a static dimension by a constant holding its
+/// static extent.
+struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
+    auto initTensorOp = dimOp.source().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp || !maybeConstantIndex)
+      return failure();
+    // The index is used as a position into the init_tensor sizes below;
+    // leave negative or out-of-range indices alone rather than handing them
+    // to the per-dimension accessors.
+    int64_t index = *maybeConstantIndex;
+    if (index < 0 ||
+        index >= static_cast<int64_t>(initTensorOp.static_sizes().size()))
+      return failure();
+    if (initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+      return success();
+    }
+    // Materialize the static extent of the dimension (not the index itself).
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+        dimOp, initTensorOp.getStaticSize(index));
+    return success();
+  }
+};
 } // namespace
 
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<FoldInitTensorWithExtractSliceOp,
+  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
               FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
               FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
               ReplaceStaticShapeDims>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -540,13 +540,10 @@
 }
 // CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: func @init_tensor_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C2:.+]] = constant 2
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C2]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7]
+// CHECK-NEXT: return %[[INIT]]
 
 // -----
 
@@ -558,13 +555,10 @@
 }
 // CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
 // CHECK: func @init_tensor_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C4:.+]] = constant 4
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C4]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]]
+// CHECK-NEXT: return %[[INIT]]
 
 // -----
 
@@ -873,3 +867,26 @@
   return %0 : tensor<2x3x4xf32>
 }
 
+// -----
+
+func private @some_use(%i : index, %j : index)
+
+// CHECK-LABEL: func @init_canonicalize
+// CHECK-SAME: %[[I:.*]]: index
+func @init_canonicalize(%i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK-NOT: init_tensor
+  %0 = linalg.init_tensor [%i, 42] : tensor<?x42xf32>
+
+  // CHECK-NOT: tensor.dim
+  %1 = tensor.dim %0, %c0: tensor<?x42xf32>
+  %2 = tensor.dim %0, %c1: tensor<?x42xf32>
+
+  // CHECK: %[[c42:.*]] = constant 42 : index
+  // CHECK: call @some_use(%[[I]], %[[c42]])
+  call @some_use(%1, %2) : (index, index) -> ()
+
+  return
+}