diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -107,6 +107,13 @@
 void getDimsOfType(Operation *op, StringRef iteratorTypeName,
                    SmallVectorImpl<AffineExpr> &res);
 
+/// For a reshape operation, compute the shape of the output based on the
+/// result type and shape of the input.
+SmallVector<Value, 4>
+getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
+                                    ArrayRef<int64_t> dstStaticShape,
+                                    ArrayRef<AffineMap> reassociation);
+
 namespace detail {
 LogicalResult verifyStructuredOpInterface(Operation *op);
 } // namespace detail
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -342,10 +342,15 @@
     SmallVector<SmallVector<AffineExpr, 2>, 4> getReassociationExprs() {
      return
        llvm::to_vector<4>(llvm::map_range(reassociation(),
-          [](Attribute a) {
-            return llvm::to_vector<2>(
-              a.cast<AffineMapAttr>().getValue().getResults());
-          }));
+        [](Attribute a) {
+          return llvm::to_vector<2>(
+            a.cast<AffineMapAttr>().getValue().getResults());
+        }));
+    }
+    SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
+      return getReshapeOutputShapeFromInputShape(
+          b, loc, src(), getResultType().getShape(),
+          getReassociationMaps());
     }
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,85 +605,6 @@
   return RankedTensorType::get(staticSizes, elementType);
 }
 
-namespace {
-/// Change the type of the result of a `linalg.init_tensor` by making the
-/// result type statically sized along dimensions that were defined as dynamic
-/// in the original operation, but whose size was defined using a `constant`
-/// op. For example
-///
-///  %c5 = constant 5: index
-///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
-///
-///  to
-///
-///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
-struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
-  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InitTensorOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value, 4> dynamicSizes;
-    SmallVector<int64_t, 4> staticSizes;
-    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
-      // If the size is already static, nothing to do.
-      if (!op.isDynamicSize(i)) {
-        staticSizes.push_back(op.getStaticSize(i));
-        continue;
-      }
-
-      // If the size is dynamic but defined using a `constant` op, get the
-      // constant value to find the static size to use.
-      unsigned operandNum = op.getIndexOfDynamicSize(i);
-      Value sizeOperand = op.getOperand(operandNum);
-      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
-        staticSizes.push_back(constantIndexOp.getValue());
-        continue;
-      }
-
-      // Fallback case. Keep the size dynamic.
-      dynamicSizes.push_back(sizeOperand);
-      staticSizes.push_back(ShapedType::kDynamicSize);
-    }
-    RankedTensorType newType =
-        RankedTensorType::get(staticSizes, op.getType().getElementType());
-    if (newType == op.getType())
-      return failure();
-    auto newOp =
-        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
-                                      rewriter.getI64ArrayAttr(staticSizes));
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
-    return success();
-  }
-};
-
-/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
-/// with
-/// - A constant value if the size is static along the dimension.
-/// - The dynamic value that defines the size of the result of
-///   `linalg.init_tensor` op.
-struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
-    if (!initTensorOp)
-      return failure();
-    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
-    if (!dimIndex)
-      return failure();
-    int64_t index = dimIndex.getValue();
-    if (!initTensorOp.isDynamicSize(index)) {
-      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
-          dimOp, initTensorOp.getStaticSize(index));
-    } else {
-      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
-    }
-    return success();
-  }
-};
-} // namespace
-
 static Value getCollapsedInitTensor(OpBuilder &builder,
                                     TensorReshapeOp reshapeOp) {
   Location loc = reshapeOp.getLoc();
@@ -773,6 +694,85 @@
                                srcType.getElementType());
 }
 
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the
+/// result type statically sized along dimensions that were defined as dynamic
+/// in the original operation, but whose size was defined using a `constant`
+/// op. For example
+///
+///  %c5 = constant 5: index
+///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+///  to
+///
+///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 4> dynamicSizes;
+    SmallVector<int64_t, 4> staticSizes;
+    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+      // If the size is already static, nothing to do.
+      if (!op.isDynamicSize(i)) {
+        staticSizes.push_back(op.getStaticSize(i));
+        continue;
+      }
+
+      // If the size is dynamic but defined using a `constant` op, get the
+      // constant value to find the static size to use.
+      unsigned operandNum = op.getIndexOfDynamicSize(i);
+      Value sizeOperand = op.getOperand(operandNum);
+      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
+        staticSizes.push_back(constantIndexOp.getValue());
+        continue;
+      }
+
+      // Fallback case. Keep the size dynamic.
+      dynamicSizes.push_back(sizeOperand);
+      staticSizes.push_back(ShapedType::kDynamicSize);
+    }
+    RankedTensorType newType =
+        RankedTensorType::get(staticSizes, op.getType().getElementType());
+    if (newType == op.getType())
+      return failure();
+    auto newOp =
+        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+                                      rewriter.getI64ArrayAttr(staticSizes));
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
+/// with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+///   `linalg.init_tensor` op.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp)
+      return failure();
+    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+    if (!dimIndex)
+      return failure();
+    int64_t index = dimIndex.getValue();
+    if (!initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+          dimOp, initTensorOp.getStaticSize(index));
+    } else {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
 /// subtensor of this is also needed only for its shape. The result can be
@@ -803,17 +803,13 @@
                                 PatternRewriter &rewriter) const override {
     if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
       return failure();
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType expandedType = reshapeOp.getResultType();
-    bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
-    if (isCollapsed)
-      std::swap(collapsedType, expandedType);
-    Value initTensorOp = isCollapsed
-                             ? getCollapsedInitTensor(rewriter, reshapeOp)
-                             : getExpandedInitTensor(rewriter, reshapeOp);
-    if (!initTensorOp)
-      return failure();
-    rewriter.replaceOp(reshapeOp, initTensorOp);
+    Location loc = reshapeOp.getLoc();
+    SmallVector<Value, 4> resultShapeValues =
+        reshapeOp.getOutputShape(rewriter, loc);
+    Value initTensor = rewriter.create<InitTensorOp>(
+        loc, resultShapeValues, reshapeOp.getResultType().getElementType());
+    rewriter.replaceOpWithNewOp<TensorCastOp>(
+        reshapeOp, reshapeOp.getResultType(), initTensor);
     return success();
   }
 };
@@ -1255,6 +1251,141 @@
   return reassociationMaps;
 }
 
+/// For a reshape op, compute the shape of the output at dimension `dimIndex`
+/// in terms of the shape of `src`, when the reshape op is a collapsing
+/// operation. It is the product of the shapes of the collapsed dimensions of
+/// `src`.
+static Value
+getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
+                                    int64_t dimIndex, Value src,
+                                    ArrayRef<AffineMap> reassociationMap) {
+  AffineMap map = reassociationMap[dimIndex];
+  unsigned startPos =
+      map.getResults().front().cast<AffineDimExpr>().getPosition();
+  unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
+  AffineExpr expr;
+  SmallVector<Value, 4> dynamicDims;
+  for (auto dim : llvm::seq(startPos, endPos + 1)) {
+    dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
+    AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
+    expr = (expr ? expr * currExpr : currExpr);
+  }
+  return applyMapToValues(builder, loc,
+                          AffineMap::get(0, endPos - startPos + 1, expr),
+                          dynamicDims)[0];
+}
+
+/// Given the `src` of a collapsing reshape op and its reassociation maps,
+/// compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
+                                                   reassociation);
+      }));
+}
+
+/// Compute a map that for a given dimension of the expanded type gives the
+/// dimension in the collapsed type it maps to. Essentially it is the inverse
+/// of the `reassociation` maps.
+static llvm::DenseMap<int64_t, int64_t>
+getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
+  for (auto map : enumerate(reassociation)) {
+    unsigned startPos =
+        map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+    unsigned endPos =
+        map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+    for (auto dim : llvm::seq(startPos, endPos + 1)) {
+      expandedDimToCollapsedDim[dim] = map.index();
+    }
+  }
+  return expandedDimToCollapsedDim;
+}
+
+/// For an expanding reshape op, compute the value for a dimension of the
+/// output from the shape of the input.
+static Value getExpandedOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
+    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
+  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
+  }
+  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
+  unsigned startPos = reassociation[sourceDimPos]
+                          .getResults()
+                          .front()
+                          .cast<AffineDimExpr>()
+                          .getPosition();
+  unsigned endPos = reassociation[sourceDimPos]
+                        .getResults()
+                        .back()
+                        .cast<AffineDimExpr>()
+                        .getPosition();
+  int64_t linearizedStaticDim = 1;
+  for (auto d :
+       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
+    if (d.index() + startPos == static_cast<size_t>(dimIndex))
+      continue;
+    assert(!ShapedType::isDynamic(d.value()) &&
+           "single dimension cannot be expanded into multiple dynamic "
+           "dimensions");
+    linearizedStaticDim *= d.value();
+  }
+  Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
+  return applyMapToValues(
+      builder, loc,
+      AffineMap::get(
+          0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
+      sourceDim)[0];
+}
+
+/// Given the `src` of an expanding reshape op, the reassociation maps and the
+/// result type, compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+      getExpandedDimToCollapsedDimMap(reassociation);
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
+                                                  dstStaticShape, reassociation,
+                                                  expandedDimToCollapsedDim);
+      }));
+}
+
+SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  return dstStaticShape.size() >
+                 static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
+             ? getExpandedOutputShapeFromInputShape(builder, loc, src,
+                                                    dstStaticShape,
+                                                    reassociation)
+             : getCollapsedOutputShapeFromInputShape(builder, loc, src,
+                                                     dstStaticShape,
+                                                     reassociation);
+}
+
+/// For a reshape op, compute the value of a given dimension of the output
+/// (`dimIndex`) from the shape of the input and the type of the result.
+static Value getReshapeOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  if (dstStaticShape.size() >
+      static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
+    llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+        getExpandedDimToCollapsedDimMap(reassociation);
+    return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                              dstStaticShape, reassociation,
+                                              expandedDimToCollapsedDim);
+  }
+  return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                             reassociation);
+}
+
 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
                                     Value src,
                                     ArrayRef<ReassociationExprs> reassociation,
@@ -1478,12 +1609,35 @@
     return success();
   }
 };
+
+/// Replace `dim` ops that query the output shape of a `tensor_reshape` with
+/// the equivalent value computed from the shape of the reshape's source.
+struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+
+    auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
+    if (!reshapeOp)
+      return failure();
+
+    rewriter.replaceOp(dimOp,
+                       getReshapeOutputDimFromInputShape(
+                           rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
+                           reshapeOp.getResultType().getShape(),
+                           reshapeOp.getReassociationMaps()));
+    return success();
+  }
+};
 } // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
-      context);
+  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
+                 ReplaceDimOfReshapeOpResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -560,10 +560,10 @@
       tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: func @init_tensor_reshape_expansion
 // CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 // CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
 // CHECK: return %[[T1]]
 
@@ -578,10 +578,10 @@
       tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   return %1 : tensor<6x5x?xf32>
 }
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
 // CHECK: func @init_tensor_reshape_collapse
 // CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 // CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
 // CHECK: return %[[T1]]
 
@@ -716,3 +716,54 @@
   } : tensor<?x?xf32> to tensor<2x4xf32>
   return
 }
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+       tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+  %2 = dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+  %3 = dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+  return %1, %2, %3 : index, index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @dim_reshape_expansion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C2]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+       tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  %1 = dim %0, %c1 : tensor<6x5x?xf32>
+  %2 = dim %0, %c2 : tensor<6x5x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @dim_reshape_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C5]], %[[D1]]
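
For reference, here is a minimal sketch of the net effect of the new
ReplaceDimOfReshapeOpResult pattern, derived from the tests above. The
function name @dim_of_collapse, the %t argument, and the single-map
reassociation are illustrative, not taken from the patch:

  // Before: a `dim` of the result of a collapsing `linalg.tensor_reshape`.
  func @dim_of_collapse(%t : tensor<4x?x7xf32>) -> index {
    %c0 = constant 0 : index
    %0 = linalg.tensor_reshape %t [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
        : tensor<4x?x7xf32> into tensor<?xf32>
    %1 = dim %0, %c0 : tensor<?xf32>
    return %1 : index
  }

  // After canonicalization: the `dim` is folded onto the reshape source. The
  // dynamic output size is the product of the collapsed input sizes,
  // 4 * (dim %t, 1) * 7, materialized as one affine.apply on the source's
  // dynamic dimension.
  func @dim_of_collapse(%t : tensor<4x?x7xf32>) -> index {
    %c1 = constant 1 : index
    %0 = dim %t, %c1 : tensor<4x?x7xf32>
    %1 = affine.apply affine_map<()[s0] -> (s0 * 28)>()[%0]
    return %1 : index
  }

As the updated CHECK lines show, the shape arithmetic is now emitted as
affine.apply (mul/floordiv) rather than muli/divi_unsigned, which keeps the
computation in the affine domain where later passes can compose and fold it.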