diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -225,8 +225,8 @@ return getResult().getType().cast(); } - // Infer the shape of the result tensor given the static shapes - // and element type of the result tensor. + // Infer the shape of the result tensor given the type of the source tensor + // and paddings. static RankedTensorType inferResultType(RankedTensorType sourceType, ArrayRef staticLow, ArrayRef staticHigh); diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -47,7 +47,7 @@ /// Convert reassociation indices to affine expressions. SmallVector, 2> convertReassociationIndicesToExprs( - OpBuilder &b, ArrayRef reassociationIndices); + MLIRContext *context, ArrayRef reassociationIndices); /// Constructs affine maps out of Array>. 
SmallVector diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1147,16 +1147,16 @@ } SmallVector TensorCollapseShapeOp::getReassociationExprs() { - OpBuilder b(this->getContext()); - return convertReassociationIndicesToExprs(b, getReassociationIndices()); + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); } SmallVector TensorExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector TensorExpandShapeOp::getReassociationExprs() { - OpBuilder b(this->getContext()); - return convertReassociationIndicesToExprs(b, getReassociationIndices()); + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); } /// For reshape op compute the shape at dimension `dimIndex` of the output in @@ -1317,7 +1317,7 @@ auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b, reassociation))); + convertReassociationIndicesToExprs(b.getContext(), reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); @@ -1330,7 +1330,7 @@ auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b, reassociation))); + convertReassociationIndicesToExprs(b.getContext(), reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1316,16 +1316,16 @@ return 
getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector CollapseShapeOp::getReassociationExprs() { - OpBuilder b(this->getContext()); - return convertReassociationIndicesToExprs(b, getReassociationIndices()); + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); } SmallVector ExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector ExpandShapeOp::getReassociationExprs() { - OpBuilder b(this->getContext()); - return convertReassociationIndicesToExprs(b, getReassociationIndices()); + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); } static void print(OpAsmPrinter &p, ExpandShapeOp op) { @@ -1427,8 +1427,8 @@ ArrayRef attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b, reassociation))); + memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( + b.getContext(), reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); @@ -1439,8 +1439,8 @@ ArrayRef attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b, reassociation))); + memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( + b.getContext(), reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); @@ -1475,10 +1475,41 @@ return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); } +/// Folds a memref.cast into collapse_shape, re-casting if the type changes. +struct CollapseShapeOpMemRefCastFolder + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CollapseShapeOp op, + PatternRewriter &rewriter) const override { +
auto cast = op.getOperand().getDefiningOp(); + if (!cast) + return failure(); + + if (!CastOp::canFoldIntoConsumerOp(cast)) + return failure(); + + Type newResultType = computeReshapeCollapsedType( + cast.getOperand().getType().cast(), + op.getReassociationMaps()); + + if (newResultType == op.getResultType()) { + rewriter.updateRootInPlace( + op, [&]() { op.srcMutable().assign(cast.source()); }); + } else { + Value newOp = rewriter.create( + op->getLoc(), cast.source(), op.getReassociationIndices()); + rewriter.replaceOpWithNewOp(op, op.getType(), newOp); + } + return success(); + } +}; + void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, - CollapseMixedReshapeOps>(context); + CollapseMixedReshapeOps, + CollapseShapeOpMemRefCastFolder>(context); } OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { if (succeeded(foldMemRefCast(*this))) @@ -1486,8 +1517,6 @@ return foldReshapeOp(*this, operands); } OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); return foldReshapeOp(*this, operands); } diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -183,13 +183,13 @@ SmallVector, 2> mlir::convertReassociationIndicesToExprs( - OpBuilder &b, ArrayRef reassociationIndices) { + MLIRContext *context, ArrayRef reassociationIndices) { SmallVector, 2> reassociationMaps; for (const auto &indices : reassociationIndices) { SmallVector reassociationMap; reassociationMap.reserve(indices.size()); for (int64_t index : indices) - reassociationMap.push_back(b.getAffineDimExpr(index)); + reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); reassociationMaps.push_back(std::move(reassociationMap)); } return reassociationMaps; diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir
b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -511,3 +511,31 @@ } // CHECK-LABEL: @fold_memref_reshape_dynamic // CHECK-NOT: linalg.{{.*}}_shape + +// ----- + +// CHECK-LABEL: func @collapse_after_memref_cast_type_change( +// CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { +// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] +// CHECK-SAME: {{\[\[}}0], [1, 2, 3]] : memref into memref +// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] : +// CHECK-SAME: memref to memref +// CHECK: return %[[DYNAMIC]] : memref +// CHECK: } +func @collapse_after_memref_cast_type_change(%arg0 : memref) -> memref { + %dynamic = memref.cast %arg0: memref to memref + %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref into memref + return %collapsed : memref +} + +// CHECK-LABEL: func @collapse_after_memref_cast( +// CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { +// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] +// CHECK-SAME: {{\[\[}}0], [1, 2, 3]] : memref into memref +// CHECK: return %[[COLLAPSED]] : memref +func @collapse_after_memref_cast(%arg0 : memref) -> memref { + %dynamic = memref.cast %arg0: memref to memref + %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref into memref + return %collapsed : memref +} +