diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -107,6 +107,13 @@ void getDimsOfType(Operation *op, StringRef iteratorTypeName,
                    SmallVectorImpl<AffineExpr> &res);
 
+/// For a reshape operation, compute the shape of the output based on the
+/// result type and the shape of the input.
+SmallVector<Value, 4>
+getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
+                                    ArrayRef<int64_t> dstStaticShape,
+                                    ArrayRef<AffineMap> reassociation);
+
 namespace detail {
 LogicalResult verifyStructuredOpInterface(Operation *op);
 } // namespace detail
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -340,10 +340,15 @@
     SmallVector<ReassociationExprs, 4> getReassociationExprs() {
       return llvm::to_vector<4>(llvm::map_range(reassociation(),
-        [](Attribute a) {
-          return llvm::to_vector<2>(
-            a.cast<AffineMapAttr>().getValue().getResults());
-        }));
+          [](Attribute a) {
+            return llvm::to_vector<2>(
+              a.cast<AffineMapAttr>().getValue().getResults());
+          }));
+    }
+    SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
+      return getReshapeOutputShapeFromInputShape(
+          b, loc, src(), getResultType().getShape(),
+          getReassociationMaps());
     }
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -397,7 +397,6 @@
 // InitTensorOp
 //===----------------------------------------------------------------------===//
 
-
 static LogicalResult verify(InitTensorOp op) {
   RankedTensorType resultType = op.getType();
   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@@ -507,95 +506,6 @@
 };
 } // namespace
 
-static Value getCollapsedInitTensor(OpBuilder &builder,
-                                    TensorReshapeOp reshapeOp) {
-  Location loc = reshapeOp.getLoc();
-  SmallVector<Value, 4> dynamicShapes;
-  SmallVector<int64_t, 4> staticShapes;
-  auto reassociation = reshapeOp.getReassociationMaps();
-  Value src = reshapeOp.src();
-  RankedTensorType srcType = reshapeOp.getSrcType();
-  ArrayRef<int64_t> srcShape = srcType.getShape();
-  for (auto map : reassociation) {
-    Value linearizedDynamicDim = nullptr;
-    int64_t linearizedStaticDim = 1;
-    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
-           return e.cast<AffineDimExpr>().getPosition();
-         })) {
-      if (ShapedType::isDynamic(srcShape[i])) {
-        Value shapeVal = builder.create<DimOp>(loc, src, i);
-        if (linearizedDynamicDim) {
-          linearizedDynamicDim =
-              builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
-        } else {
-          linearizedDynamicDim = shapeVal;
-        }
-      } else {
-        linearizedStaticDim *= srcShape[i];
-      }
-    }
-    if (linearizedDynamicDim) {
-      if (linearizedStaticDim != 1) {
-        linearizedDynamicDim = builder.create<MulIOp>(
-            loc, linearizedDynamicDim,
-            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
-      }
-      dynamicShapes.push_back(linearizedDynamicDim);
-      staticShapes.push_back(ShapedType::kDynamicSize);
-    } else {
-      staticShapes.push_back(linearizedStaticDim);
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
-                                      srcType.getElementType());
-}
-
-static Value getExpandedInitTensor(OpBuilder &builder,
-                                   TensorReshapeOp reshapeOp) {
-  SmallVector<Value, 4> dynamicShapes;
-  SmallVector<int64_t, 4> staticShapes;
-  auto reassociation = reshapeOp.getReassociationMaps();
-  Value src = reshapeOp.src();
-  RankedTensorType srcType = reshapeOp.getSrcType();
-  ArrayRef<int64_t> srcShape = srcType.getShape();
-  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
-  Location loc = reshapeOp.getLoc();
-  for (auto map : enumerate(reassociation)) {
-    int64_t linearizedStaticDim = 1;
-    bool hasDynamic = false;
-    for (unsigned i :
-         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
-           return e.cast<AffineDimExpr>().getPosition();
-         })) {
-      if (ShapedType::isDynamic(dstShape[i])) {
-        // Only one of the dimensions of the expanded shape should be dynamic.
-        if (hasDynamic)
-          return nullptr;
-        hasDynamic = true;
-        staticShapes.push_back(ShapedType::kDynamicSize);
-        continue;
-      }
-      staticShapes.push_back(dstShape[i]);
-      linearizedStaticDim *= dstShape[i];
-    }
-    if (hasDynamic) {
-      // If the expanded dimensions has a dynamic shape, the src shape must be
-      // dynamic as well.
-      if (!ShapedType::isDynamic(srcShape[map.index()]))
-        return nullptr;
-      Value dynamicDim = builder.create<DimOp>(loc, src, map.index());
-      if (linearizedStaticDim != 1) {
-        dynamicDim = builder.create<UnsignedDivIOp>(
-            loc, dynamicDim,
-            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
-      }
-      dynamicShapes.push_back(dynamicDim);
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
-                                      srcType.getElementType());
-}
-
 namespace {
 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
 /// subtensor of this is also needed only for its shape. The result can be
@@ -626,17 +536,13 @@
                                 PatternRewriter &rewriter) const override {
     if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
       return failure();
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType expandedType = reshapeOp.getResultType();
-    bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
-    if (isCollapsed)
-      std::swap(collapsedType, expandedType);
-    Value initTensorOp = isCollapsed
-                             ? getCollapsedInitTensor(rewriter, reshapeOp)
-                             : getExpandedInitTensor(rewriter, reshapeOp);
-    if (!initTensorOp)
-      return failure();
-    rewriter.replaceOp(reshapeOp, initTensorOp);
+    Location loc = reshapeOp.getLoc();
+    SmallVector<Value, 4> resultShapeValues =
+        reshapeOp.getOutputShape(rewriter, loc);
+    Value initTensor = rewriter.create<InitTensorOp>(
+        loc, resultShapeValues, reshapeOp.getResultType().getElementType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        reshapeOp, reshapeOp.getResultType(), initTensor);
     return success();
   }
 };
@@ -1096,6 +1002,141 @@
   return reassociationMaps;
 }
 
+/// For a reshape op, compute the shape at dimension `dimIndex` of the output
+/// in terms of the shape of the `src`, when the reshape op is a collapsing
+/// operation. It is the product of the shapes of the collapsed dimensions of
+/// the `src`.
+static Value
+getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
+                                    int64_t dimIndex, Value src,
+                                    ArrayRef<AffineMap> reassociationMap) {
+  AffineMap map = reassociationMap[dimIndex];
+  unsigned startPos =
+      map.getResults().front().cast<AffineDimExpr>().getPosition();
+  unsigned endPos =
+      map.getResults().back().cast<AffineDimExpr>().getPosition();
+  AffineExpr expr = nullptr;
+  SmallVector<Value, 4> dynamicDims;
+  for (auto dim : llvm::seq(startPos, endPos + 1)) {
+    dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
+    AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
+    expr = (expr ? expr * currExpr : currExpr);
+  }
+  return applyMapToValues(builder, loc,
+                          AffineMap::get(0, endPos - startPos + 1, expr),
+                          dynamicDims)[0];
+}
+
+/// Given the `src` of a collapsing reshape op and its reassociation maps,
+/// compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
+                                                   reassociation);
+      }));
+}
+
+/// Compute a map that for a given dimension of the expanded type gives the
+/// dimension in the collapsed type it maps to. Essentially it is the inverse
+/// of the `reassociation` maps.
+static llvm::DenseMap<int64_t, int64_t>
+getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
+  for (auto map : enumerate(reassociation)) {
+    unsigned startPos =
+        map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+    unsigned endPos =
+        map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+    for (auto dim : llvm::seq(startPos, endPos + 1)) {
+      expandedDimToCollapsedDim[dim] = map.index();
+    }
+  }
+  return expandedDimToCollapsedDim;
+}
+
+/// For an expanding reshape op, compute the value for a dimension of the
+/// output from the shape of the input.
+static Value getExpandedOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
+    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
+  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
+  }
+  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
+  unsigned startPos = reassociation[sourceDimPos]
+                          .getResults()
+                          .front()
+                          .cast<AffineDimExpr>()
+                          .getPosition();
+  unsigned endPos = reassociation[sourceDimPos]
+                        .getResults()
+                        .back()
+                        .cast<AffineDimExpr>()
+                        .getPosition();
+  int64_t linearizedStaticDim = 1;
+  for (auto d :
+       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
+    if (d.index() + startPos == static_cast<size_t>(dimIndex))
+      continue;
+    assert(!ShapedType::isDynamic(d.value()) &&
+           "single dimension cannot be expanded into multiple dynamic "
+           "dimensions");
+    linearizedStaticDim *= d.value();
+  }
+  Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
+  return applyMapToValues(
+      builder, loc,
+      AffineMap::get(
+          0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
+      sourceDim)[0];
+}
+
+/// Given the `src` of an expanding reshape op, the reassociation maps and the
+/// result type, compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+      getExpandedDimToCollapsedDimMap(reassociation);
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
+                                                  dstStaticShape, reassociation,
+                                                  expandedDimToCollapsedDim);
+      }));
+}
+
+SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  return dstStaticShape.size() >
+                 static_cast<size_t>(
+                     src.getType().cast<ShapedType>().getRank())
+             ? getExpandedOutputShapeFromInputShape(builder, loc, src,
+                                                    dstStaticShape,
+                                                    reassociation)
+             : getCollapsedOutputShapeFromInputShape(builder, loc, src,
+                                                     dstStaticShape,
+                                                     reassociation);
+}
+
+/// For a reshape op, compute the value of a given dimension of the output
+/// (`dimIndex`) from the shape of the inputs and type of the result.
+static Value getReshapeOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  if (dstStaticShape.size() >
+      static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
+    llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+        getExpandedDimToCollapsedDimMap(reassociation);
+    return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                              dstStaticShape, reassociation,
+                                              expandedDimToCollapsedDim);
+  }
+  return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                             reassociation);
+}
+
 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
                                     Value src,
                                     ArrayRef<ReassociationExprs> reassociation,
@@ -1319,12 +1360,35 @@
     return success();
   }
 };
+
+/// Canonicalize dim ops that use the output shape with dim of the input.
+struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+
+    auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
+    if (!reshapeOp)
+      return failure();
+
+    rewriter.replaceOp(dimOp,
+                       getReshapeOutputDimFromInputShape(
+                           rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
+                           reshapeOp.getResultType().getShape(),
+                           reshapeOp.getReassociationMaps()));
+    return success();
+  }
+};
 } // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
-      context);
+  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
+                 ReplaceDimOfReshapeOpResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -560,10 +560,10 @@
       tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: func @init_tensor_reshape_expansion
 // CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 // CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
 // CHECK: return %[[T1]]
 
@@ -578,10 +578,10 @@
       tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   return %1 : tensor<6x5x?xf32>
 }
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
 // CHECK: func @init_tensor_reshape_collapse
 // CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 // CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
 // CHECK: return %[[T1]]
 
@@ -716,3 +716,54 @@
   } : tensor<?x?xf32> to tensor<2x4xf32>
   return
 }
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+  %2 = dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+  %3 = dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+  return %1, %2, %3 : index, index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @dim_reshape_expansion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C2]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  %1 = dim %0, %c1 : tensor<6x5x?xf32>
+  %2 = dim %0, %c2 : tensor<6x5x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @dim_reshape_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C5]], %[[D1]]
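
Note on the emitted shape arithmetic (illustrative sketch, not part of the
patch): with the reassociation [(d0, d1), (d2), (d3, d4, d5)] used in the
tests above, the dynamic output dim 4 of the expansion belongs to the group
(d3, d4, d5), which maps to source dim 2; its static siblings contribute
4 * 7 = 28, hence the map (s0 floordiv 28). The collapse goes the other way:
output dim 2 linearizes source dims 3, 4 and 5, i.e. 4 * ? * 7, hence
(s0 * 28). A minimal MLIR sketch of the expansion case (the function name is
hypothetical):

  func @expanded_dynamic_dim(%arg0 : tensor<6x5x?xf32>) -> index {
    %c4 = constant 4 : index
    %0 = linalg.tensor_reshape %arg0
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
      tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
    // Dim 4 is the only dynamic dim in its reassociation group (d3, d4, d5).
    %1 = dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
    return %1 : index
  }

  // Expected form after canonicalization: the dim is computed from the source
  // alone, so the reshape becomes dead if %1 was its only use:
  //   %d = dim %arg0, %c2 : tensor<6x5x?xf32>
  //   %1 = affine.apply affine_map<()[s0] -> (s0 floordiv 28)>()[%d]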