diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -718,9 +718,142 @@
 };
 } // namespace
 
+/// Builds a `linalg.init_tensor` with the result type of the collapsing
+/// `reshapeOp`. Each result dimension is the product of the source
+/// dimensions in its reassociation group: static sizes are folded into one
+/// constant, dynamic sizes are queried from `src` and multiplied together,
+/// and the two parts are combined when both are present.
+static Value getCollapsedInitTensor(OpBuilder &builder,
+                                    TensorReshapeOp reshapeOp) {
+  Location loc = reshapeOp.getLoc();
+  SmallVector dynamicShapes;
+  SmallVector staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef srcShape = srcType.getShape();
+  for (auto map : reassociation) {
+    Value linearizedDynamicDim = nullptr;
+    int64_t linearizedStaticDim = 1;
+    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
+           return e.cast().getPosition();
+         })) {
+      if (ShapedType::isDynamic(srcShape[i])) {
+        Value shapeVal = builder.create(loc, src, i);
+        if (linearizedDynamicDim) {
+          linearizedDynamicDim =
+              builder.create(loc, linearizedDynamicDim, shapeVal);
+        } else {
+          linearizedDynamicDim = shapeVal;
+        }
+      } else {
+        linearizedStaticDim *= srcShape[i];
+      }
+    }
+    if (linearizedDynamicDim) {
+      if (linearizedStaticDim != 1) {
+        linearizedDynamicDim = builder.create(
+            loc, linearizedDynamicDim,
+            builder.create(loc, linearizedStaticDim));
+      }
+      dynamicShapes.push_back(linearizedDynamicDim);
+      staticShapes.push_back(ShapedType::kDynamicSize);
+    } else {
+      staticShapes.push_back(linearizedStaticDim);
+    }
+  }
+  return builder.create(loc, dynamicShapes, staticShapes,
+                        srcType.getElementType());
+}
+
+/// Builds a `linalg.init_tensor` with the result type of the expanding
+/// `reshapeOp`, or returns nullptr if that is not possible. Static result
+/// sizes come from the result type. A reassociation group may hold at most
+/// one dynamic result dimension; its size is the matching source dimension
+/// divided by the product of the static sizes in the same group.
+static Value getExpandedInitTensor(OpBuilder &builder,
+                                   TensorReshapeOp reshapeOp) {
+  SmallVector dynamicShapes;
+  SmallVector staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef dstShape = reshapeOp.getResultType().getShape();
+  Location loc = reshapeOp.getLoc();
+
+  // Reject unsupported groups before creating any op: the caller turns a
+  // nullptr into a pattern failure, and a failing pattern must leave the IR
+  // unmodified.
+  for (auto map : reassociation) {
+    unsigned numDynamicDims = 0;
+    for (AffineExpr e : map.getResults())
+      if (ShapedType::isDynamic(dstShape[e.cast().getPosition()]))
+        ++numDynamicDims;
+    if (numDynamicDims > 1)
+      return nullptr;
+  }
+
+  for (auto map : llvm::enumerate(reassociation)) {
+    int64_t linearizedStaticDim = 1;
+    bool hasDynamic = false;
+    for (unsigned i :
+         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
+           return e.cast().getPosition();
+         })) {
+      if (ShapedType::isDynamic(dstShape[i])) {
+        hasDynamic = true;
+        staticShapes.push_back(ShapedType::kDynamicSize);
+        continue;
+      }
+      staticShapes.push_back(dstShape[i]);
+      linearizedStaticDim *= dstShape[i];
+    }
+    if (hasDynamic) {
+      // Read the source shape inline: a local used only by this assert would
+      // be an unused variable in NDEBUG builds.
+      assert(ShapedType::isDynamic(srcType.getShape()[map.index()]) &&
+             "expected shape in collapsed type to be dynamic as well");
+      Value dynamicDim = builder.create(loc, src, map.index());
+      if (linearizedStaticDim != 1) {
+        dynamicDim = builder.create(
+            loc, dynamicDim,
+            builder.create(loc, linearizedStaticDim));
+      }
+      dynamicShapes.push_back(dynamicDim);
+    }
+  }
+  return builder.create(loc, dynamicShapes, staticShapes,
+                        srcType.getElementType());
+}
+
+namespace {
+/// Replaces a `linalg.tensor_reshape` of a `linalg.init_tensor` with a new
+/// `linalg.init_tensor` of the reshaped type.
+struct FoldWithTensorReshapeOp : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.src().getDefiningOp())
+      return failure();
+    // Reshaping to a lower rank collapses dimensions; every other reshape
+    // goes through the expansion path.
+    bool isCollapsing =
+        reshapeOp.getResultType().getRank() < reshapeOp.getSrcType().getRank();
+    Value initTensorOp = isCollapsing
+                             ? getCollapsedInitTensor(rewriter, reshapeOp)
+                             : getExpandedInitTensor(rewriter, reshapeOp);
+    if (!initTensorOp)
+      return failure();
+    rewriter.replaceOp(reshapeOp, initTensorOp);
+    return success();
+  }
+};
+} // namespace
+
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert(context);
+  results.insert(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1043,23 +1176,23 @@
   ArrayRef expandedShape = expandedType.getShape();
   unsigned expandedDimStart = 0;
   for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional dynamicDims;
+    Optional dynamicShape;
     int64_t linearizedStaticShape = 1;
     for (auto dim : llvm::enumerate(expandedShape.slice(
              expandedDimStart, map.value().getNumResults()))) {
       if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicDims) {
+        if (isExpandingReshape && dynamicShape) {
           return op->emitOpError("invalid to have a single dimension (")
                  << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicDims.getValue() << ","
+                 << expandedDimStart + dynamicShape.getValue() << ","
                  << expandedDimStart + dim.index() << ")";
         }
-        dynamicDims = dim.index();
+        dynamicShape = dim.index();
       } else {
         linearizedStaticShape *= dim.value();
       }
     }
-    if (dynamicDims) {
+    if (dynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return op->emitOpError("expected dimension ")
                << map.index()
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -413,3 +413,39 @@
 // CHECK-SAME: [[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor)
 // CHECK: dim [[ARG_0]]
 // CHECK: dim [[ARG_1]]
+
+// -----
+
+func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: func @init_tensor_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[T1]]
+
+// -----
+
+func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+// CHECK: func @init_tensor_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[T1]]