diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -718,9 +718,118 @@
 };
 } // namespace
 
+static Value getCollapsedInitTensor(OpBuilder &builder,
+                                    TensorReshapeOp reshapeOp) {
+  Location loc = reshapeOp.getLoc();
+  SmallVector<Value, 4> dynamicDims;
+  SmallVector<int64_t, 4> staticDims;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  for (auto map : reassociation) {
+    Value linearizedDynamicDim = nullptr;
+    int64_t linearizedStaticDim = 1;
+    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(srcShape[i])) {
+        Value shapeVal = builder.create<DimOp>(loc, src, i);
+        if (linearizedDynamicDim) {
+          linearizedDynamicDim =
+              builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
+        } else {
+          linearizedDynamicDim = shapeVal;
+        }
+      } else {
+        linearizedStaticDim *= srcShape[i];
+      }
+    }
+    if (linearizedDynamicDim) {
+      if (linearizedStaticDim != 1) {
+        linearizedDynamicDim = builder.create<MulIOp>(
+            loc, linearizedDynamicDim,
+            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+      }
+      dynamicDims.push_back(linearizedDynamicDim);
+      staticDims.push_back(ShapedType::kDynamicSize);
+    } else {
+      staticDims.push_back(linearizedStaticDim);
+    }
+  }
+  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
+                                      srcType.getElementType());
+}
+
+static Value getExpandedInitTensor(OpBuilder &builder,
+                                   TensorReshapeOp reshapeOp) {
+  SmallVector<Value, 4> dynamicDims;
+  SmallVector<int64_t, 4> staticDims;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
+  Location loc = reshapeOp.getLoc();
+  for (auto map : llvm::enumerate(reassociation)) {
+    int64_t linearizedStaticDim = 1;
+    bool hasDynamic = false;
+    for (unsigned i :
+         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(dstShape[i])) {
+        assert(!hasDynamic &&
+               "expected at most one dynamic dim per reassociation group");
+        hasDynamic = true;
+        staticDims.push_back(ShapedType::kDynamicSize);
+        continue;
+      }
+      staticDims.push_back(dstShape[i]);
+      linearizedStaticDim *= dstShape[i];
+    }
+    if (hasDynamic) {
+      assert(ShapedType::isDynamic(srcShape[map.index()]) &&
+             "expected dim in collapsed type to be dynamic as well");
+      Value dynamicDim = builder.create<DimOp>(loc, src, map.index());
+      if (linearizedStaticDim != 1) {
+        dynamicDim = builder.create<UnsignedDivIOp>(
+            loc, dynamicDim,
+            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+      }
+      dynamicDims.push_back(dynamicDim);
+    }
+  }
+  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
+                                      srcType.getElementType());
+}
+
+namespace {
+struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
+      return failure();
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType expandedType = reshapeOp.getResultType();
+    bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
+    if (isCollapsed)
+      std::swap(collapsedType, expandedType);
+    Value initTensorOp = isCollapsed
+                             ? getCollapsedInitTensor(rewriter, reshapeOp)
+                             : getExpandedInitTensor(rewriter, reshapeOp);
+    rewriter.replaceOp(reshapeOp, initTensorOp);
+    return success();
+  }
+};
+} // namespace
+
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReplaceStaticShapeDims>(context);
+  results.insert<FoldWithTensorReshapeOp, ReplaceStaticShapeDims>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -413,3 +413,39 @@
 // CHECK-SAME: [[ARG_0:%.*]]: tensor<?x?xf32>, [[ARG_1:%.*]]: tensor<?x?xf32>)
 // CHECK: dim [[ARG_0]]
 // CHECK: dim [[ARG_1]]
+
+// -----
+
+func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: func @init_tensor_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[T1]]
+
+// -----
+
+func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+// CHECK: func @init_tensor_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[T1]]
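
Note on the tests: the `28` in both CHECK sequences is the product of the static extents in the reassociation group `(d3, d4, d5) -> (4, ?, 7)`. Expansion recovers the new dynamic extent by dividing the source's dynamic extent by `4 * 7 = 28` (hence `divi_unsigned`), and collapsing linearizes it by multiplying back (hence `muli`). When every extent in a group is static, `getExpandedInitTensor` pushes only static sizes and no arithmetic is emitted. A minimal sketch of that case (hypothetical input, not one of the patch's tests) would canonicalize as:

  // Fully static expansion: 6 = 2 * 3 and 42 = 6 * 7, no dynamic dims.
  func @init_tensor_reshape_static() -> tensor<2x3x5x6x7xf32> {
    %0 = linalg.init_tensor [6, 5, 42] : tensor<6x5x42xf32>
    %1 = linalg.tensor_reshape %0
      [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
       affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
       affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
      tensor<6x5x42xf32> into tensor<2x3x5x6x7xf32>
    return %1 : tensor<2x3x5x6x7xf32>
  }
  // Expected to fold to a single op, with no constant/divi_unsigned emitted:
  //   %0 = linalg.init_tensor [2, 3, 5, 6, 7] : tensor<2x3x5x6x7xf32>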