diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -330,7 +330,12 @@ // be either a contracting or expanding reshape. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, CArg<"ArrayRef", "{}">:$attrs), @@ -355,21 +360,33 @@ return reassociationIndices; }; }]; + + let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } def IndexListArrayAttr : TypedArrayAttrBase; -def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape", +class Linalg_ReshapeOp : Linalg_ReshapeLikeOp]>, Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)> { - let summary = "linalg.reshape produces a new view into the operand view"; + let extraClassDeclaration = commonExtraClassDeclaration # [{ + MemRefType getSrcType() { return src().getType().cast(); } + MemRefType getResultType() { return result().getType().cast(); } + }]; + let hasFolder = 1; + let hasCanonicalizer = 1; + let printer = [{ return ::print(p, *this); }]; +} + +def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> { + let summary = "operation to produce a memref with a higher rank."; let description = [{ - The `linalg.reshape` op produces a new view whose sizes are a reassociation - of the original `view`. Depending on whether or not the reassociated - MemRefType is contiguous, the resulting memref may require explicit alloc - and copies. 
+ The `linalg.expand_shape` op produces a new view with a higher rank whose + sizes are a reassociation of the original `view`. Depending on whether or + not the reassociated MemRefType is contiguous, the resulting memref may + require explicit alloc and copies. A reassociation is defined as a continuous grouping of dimensions and is represented with an array of I64ArrayAttr attribute. @@ -381,85 +398,67 @@ All other cases are undefined behavior and a reshape op may not lower to LLVM if it cannot be proven statically that it does not require alloc+copy. - A reshape may either collapse or expand dimensions, depending on the - relationship between source and target memref ranks. The verification rule - is that the reassociation maps are applied to the memref with the larger - rank to obtain the memref with the smaller rank. In the case of a dimension - expansion, the reassociation maps can be interpreted as inverse maps. - - The result memref type of a reshape when dimensions are collapsed - (operand memref type when dimensions are expanded) can be - zero-ranked if the operand memref type (or the result memref type - when dimensions are expanded) is statically shaped with all - dimensions being unit extent. In such cases the reassociation map - is empty. + The operand memref type of a reshape can be zero-ranked if the result + memref type is statically shaped with all dimensions being unit extent. In + such cases the reassociation map is empty. - Examples: + The verification rule is that the reassociation maps are applied to the + result memref with the larger rank to obtain the operand memref with the + smaller rank.
- ```mlir - // Dimension collapse (i, j) -> i' and k -> k' - %1 = linalg.reshape %0 [[0, 1], [2]] : - memref into memref - ``` + Example: ```mlir // Dimension expansion i -> (i', j') and (k) -> (k') - %1 = linalg.reshape %0 [[0, 1], [2]] : + %1 = linalg.expand_shape %0 [[0, 1], [2]] : memref into memref ``` }]; - let extraClassDeclaration = commonExtraClassDeclaration # [{ - MemRefType getSrcType() { return src().getType().cast(); } - MemRefType getResultType() { return result().getType().cast(); } - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } -def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< - "tensor_reshape", - [DeclareOpInterfaceMethods]>, - Arguments<(ins AnyTensor:$src, - IndexListArrayAttr:$reassociation)>, - Results<(outs AnyTensor:$result)> { - let summary = "linalg.tensor_reshape produces a new reshaped tensor."; +def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> { + let summary = "operation to produce a memref with a smaller rank."; let description = [{ - The `linalg.reshape` op produces a new tensor whose sizes are a - reassociation of the original `src`. + The `linalg.collapse_shape` op produces a new view with a smaller rank + whose sizes are a reassociation of the original `view`. Depending on + whether or not the reassociated MemRefType is contiguous, the resulting + memref may require explicit alloc and copies. A reassociation is defined as a continuous grouping of dimensions and is represented with an array of I64ArrayAttr attribute. - A reshape may either collapse or expand dimensions, depending on the - relationship between source and target tensor ranks. The verification rule - is that the reassociation maps are applied to the tensor with the larger - rank to obtain the tensor with the smaller rank. In the case of a dimension - expansion, the reassociation maps can be interpreted as inverse maps. 
+ For now, it is assumed that either: + 1. a reassociation produces and consumes contiguous MemRefType or, + 2. the reshape op will be folded into its consumers (by changing the shape + of the computations). + All other cases are undefined behavior and a reshape op may not lower to + LLVM if it cannot be proven statically that it does not require alloc+copy. + + The result memref type of a reshape can be zero-ranked if the operand + memref type is statically shaped with all dimensions being unit extent. In + such case the reassociation map is empty. - The result tensor type of a reshape when dimensions are collapsed - (operand tensor type when dimensions are expanded) can be - zero-ranked if the operand tensor type (or the result tensor type - when dimensions are expanded) is statically shaped with all - dimensions being unit extent. In such cases the reassociation map - is empty. + The verification rule is that the reassociation maps are applied to the + operand memref with the larger rank to obtain the result memref with the + smaller rank. 
Examples: ```mlir // Dimension collapse (i, j) -> i' and k -> k' - %b = linalg.tensor_reshape %a [[0, 1], [2]] - : tensor into tensor - ``` - - ```mlir - // Dimension expansion i -> (i', j') and (k) -> (k') - %b = linalg.tensor_reshape %a [[0, 1], [2]] - : tensor into tensor + %1 = linalg.collapse_shape %0 [[0, 1], [2]] : + memref into memref ``` }]; +} + +class Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< + mnemonic, + [DeclareOpInterfaceMethods]>, + Arguments<(ins AnyTensor:$src, + IndexListArrayAttr:$reassociation)>, + Results<(outs AnyTensor:$result)> { let extraClassDeclaration = commonExtraClassDeclaration # [{ RankedTensorType getSrcType() { return src().getType().cast(); @@ -474,6 +473,60 @@ let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } +def Linalg_TensorExpandShapeOp : Linalg_TensorReshapeOp<"tensor_expand_shape"> { + let summary = "operation to produce a tensor with a higher rank"; + let description = [{ + The `linalg.tensor_expand_shape` op produces a new tensor with a higher + rank whose sizes are a reassociation of the original `src`. + + A reassociation is defined as a continuous grouping of dimensions and is + represented with an array of I64ArrayAttr attribute. + + The verification rule is that the reassociation maps are applied to the + result tensor with the higher rank to obtain the operand tensor with the + smaller rank. + + The operand tensor type of a reshape can be zero-ranked if the result + tensor type is statically shaped with all dimensions being unit extent. In + such cases the reassociation map is empty. 
+ + Examples: + + ```mlir + // Dimension expansion i -> (i', j') and (k) -> (k') + %b = linalg.tensor_expand_shape %a [[0, 1], [2]] + : tensor into tensor + ``` + }]; +} + +def Linalg_TensorCollapseShapeOp : Linalg_TensorReshapeOp<"tensor_collapse_shape"> { + let summary = "operation to produce a tensor with a smaller rank"; + let description = [{ + The `linalg.tensor_collapse_shape` op produces a new tensor with a smaller + rank whose sizes are a reassociation of the original `src`. + + A reassociation is defined as a continuous grouping of dimensions and is + represented with an array of I64ArrayAttr attribute. + + The verification rule is that the reassociation maps are applied to the + operand tensor with the higher rank to obtain the result tensor with the + smaller rank. + + The result tensor type of a reshape can be zero-ranked if the operand + tensor type is statically shaped with all dimensions being unit extent. In + such case the reassociation map is empty. + + Examples: + + ```mlir + // Dimension collapse (i, j) -> i' and k -> k' + %b = linalg.tensor_collapse_shape %a [[0, 1], [2]] + : tensor into tensor + ``` + }]; +} + def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -95,9 +95,11 @@ // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. 
+template class ReshapeOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; LogicalResult matchAndRewrite(ReshapeOp reshapeOp, ArrayRef operands, @@ -118,8 +120,9 @@ ReshapeOpAdaptor adaptor(operands); MemRefDescriptor baseDesc(adaptor.src()); Location loc = reshapeOp->getLoc(); - auto desc = MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(), - typeConverter->convertType(dstType)); + auto desc = + MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(), + this->typeConverter->convertType(dstType)); desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc)); desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc)); desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc)); @@ -149,7 +152,8 @@ /// Populate the given list with patterns that convert from Linalg to LLVM. void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add( + patterns.add, + ReshapeOpConversion, YieldOpConversion>( converter); // Populate the type conversions for the linalg types. 
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -191,7 +191,8 @@ target.addLegalDialect(); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(&getContext()); populateLinalgToStandardConversionPatterns(patterns); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1188,16 +1188,20 @@ getIdentityExprs(resultTy.getShape().size())}; auto collapsedTy = RankedTensorType::get({totalElems}, elemTy); - Value collapsedOp = rewriter.create( + Value collapsedOp = rewriter.create( loc, collapsedTy, args[0], collapsingMap); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( reshape, resultTy, collapsedOp, expandingMap); return success(); } - rewriter.replaceOpWithNewOp( - reshape, resultTy, args[0], reassociationMap); + if (resultTy.getRank() < args[0].getType().cast().getRank()) + rewriter.replaceOpWithNewOp( + reshape, resultTy, args[0], reassociationMap); + else + rewriter.replaceOpWithNewOp( + reshape, resultTy, args[0], reassociationMap); return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -882,13 +882,14 @@ } }; +template struct FoldInitTensorWithTensorReshapeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { - if (!reshapeOp.src().getDefiningOp()) + if (!reshapeOp.src().template 
getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); SmallVector, 4> resultShapes; @@ -912,7 +913,9 @@ void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, + FoldInitTensorWithTensorReshapeOp, ReplaceStaticShapeDims>(context); } @@ -1206,12 +1209,24 @@ p << ": " << op.src().getType() << " into " << op.getType(); } -static void print(OpAsmPrinter &p, linalg::ReshapeOp op) { - print(p, op); +static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) { + print(p, op); } -static void print(OpAsmPrinter &p, linalg::TensorReshapeOp op) { - print(p, op); +static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) { + print(p, op); +} + +static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) { + print(p, op); +} + +static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) { + print(p, op); +} + +static constexpr StringRef getReassociationAttrName() { + return "reassociation"; } static ParseResult parseReshapeLikeOp(OpAsmParser &parser, @@ -1253,7 +1268,7 @@ break; } - result.addAttribute(ReshapeOp::getReassociationAttrName(), + result.addAttribute(getReassociationAttrName(), b.getArrayAttr(reassociation)); // Parse optional attributes. @@ -1334,36 +1349,10 @@ ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); ShapedType intermediateType = reshapeOp.getSrcType(); ShapedType resultType = reshapeOp.getResultType(); - - auto areReshapeOpsFoldable = [](ShapedType largerType, - ShapedType intermediateType, - ShapedType smallerType) -> bool { - return largerType.getRank() > intermediateType.getRank() && - intermediateType.getRank() > smallerType.getRank(); - }; Optional> reassociationIndices = - llvm::None; - // Check if producer and consumer are both expanding dims or both collapsing - // dims. In this case, try to compose the affine maps. This works for - // dynamic shapes too. 
- if (areReshapeOpsFoldable(resultType, intermediateType, - srcReshapeSrcType) || - areReshapeOpsFoldable(srcReshapeSrcType, intermediateType, - resultType)) { - reassociationIndices = collapseReassociationIndices( - srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(), - rewriter.getContext()); - } - if (!reassociationIndices) { - // If the source reshape can be collapsed/expanded into the target reshape - // they can still be folded. This can only be reasoned about statically - // for cases where - // - either all shapes are static, or - // - The number of dynamic dimensions matches in the source of source and - // result with all other dimensions being 1. - reassociationIndices = - getReassociationIndicesForReshape(srcReshapeSrcType, resultType); - } + collapseReassociationIndices(srcReshapeOp.getReassociationMaps(), + reshapeOp.getReassociationMaps(), + rewriter.getContext()); if (!reassociationIndices) return failure(); rewriter.replaceOpWithNewOp( @@ -1371,15 +1360,55 @@ return success(); } }; + +/// Pattern to collapse producer/consumer reshape ops that are both collapsing +/// dimensions or are both expanding dimensions. +template +struct CollapseMixedReshapeOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, + PatternRewriter &rewriter) const override { + auto srcReshapeOp = + reshapeOp.src().template getDefiningOp(); + if (!srcReshapeOp) + return failure(); + + ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); + ShapedType intermediateType = reshapeOp.getSrcType(); + ShapedType resultType = reshapeOp.getResultType(); + + // If the source reshape can be collapsed/expanded into the target reshape + // they can still be folded. This can only be reasoned about statically + // for cases where + // - either all shapes are static, or + // - The number of dynamic dimensions matches in the source of source and + // result with all other dimensions being 1. 
+ Optional> reassociationIndices = + getReassociationIndicesForReshape(srcReshapeSrcType, resultType); + if (!reassociationIndices) + return failure(); + bool originalOpExpands = + intermediateType.getRank() > srcReshapeSrcType.getRank(); + bool resultingOpExpands = + resultType.getRank() > srcReshapeSrcType.getRank(); + if (!(resultingOpExpands ^ originalOpExpands)) + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); + else + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); + return success(); + } +}; } // namespace -template +template static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef operands) { // Fold producer-consumer reshape ops that where the operand type of the // producer is same as the return type of the consumer. - ReshapeOpTy reshapeSrcOp = - reshapeOp.src().template getDefiningOp(); + auto reshapeSrcOp = + reshapeOp.src().template getDefiningOp(); if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) return reshapeSrcOp.src(); // Reshape of a constant can be replaced with a new constant. 
@@ -1564,20 +1593,38 @@ return reassociationMaps; } -SmallVector ReshapeOp::getReassociationMaps() { +SmallVector CollapseShapeOp::getReassociationMaps() { + return getSymbolLessAffineMaps(getReassociationExprs()); +} +SmallVector CollapseShapeOp::getReassociationExprs() { + OpBuilder b(this->getContext()); + return convertReassociationIndicesToExprs(b, getReassociationIndices()); +} +SmallVector ExpandShapeOp::getReassociationMaps() { + return getSymbolLessAffineMaps(getReassociationExprs()); +} +SmallVector ExpandShapeOp::getReassociationExprs() { + OpBuilder b(this->getContext()); + return convertReassociationIndicesToExprs(b, getReassociationIndices()); +} + +SmallVector TensorCollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } -SmallVector ReshapeOp::getReassociationExprs() { +SmallVector +TensorCollapseShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } -SmallVector TensorReshapeOp::getReassociationMaps() { +SmallVector TensorExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } -SmallVector TensorReshapeOp::getReassociationExprs() { +SmallVector +TensorExpandShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } + /// For reshape op compute the shape at dimension `dimIndex` of the output in /// terms of shape of the `src`, when the reshape op is a collapsing /// operation. 
It is the product of the shape of the collapsed dimensions of the @@ -1708,7 +1755,7 @@ return b.getArrayAttr(reassociationAttr); } -void mlir::linalg::ReshapeOp::build( +void mlir::linalg::ExpandShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { @@ -1717,20 +1764,26 @@ memRefType, getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); - result.addAttribute(ReshapeOp::getReassociationAttrName(), + result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } -void mlir::linalg::ReshapeOp::build( - OpBuilder &b, OperationState &result, Type resultType, Value src, +Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); } + +void mlir::linalg::CollapseShapeOp::build( + OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { + auto memRefType = src.getType().cast(); + auto resultType = computeReshapeCollapsedType( + memRefType, getSymbolLessAffineMaps( + convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); - result.addAttribute(ReshapeOp::getReassociationAttrName(), + result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } -Value mlir::linalg::ReshapeOp::getViewSource() { return src(); } +Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); } /// Verify that shapes of the reshaped types using following rules /// 1) if a dimension in the collapsed type is static, then the corresponding @@ -1785,18 +1838,17 @@ // Common verifier for reshape-like types. Fills `expandedType` and // `collapsedType` with the proper `src` or `result` type. 
-template -static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, - T &collapsedType) { - expandedType = op.getSrcType(); - collapsedType = op.getResultType(); +template ::value || + std::is_same::value> +static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, + T collapsedType) { unsigned expandedRank = expandedType.getRank(); unsigned collapsedRank = collapsedType.getRank(); - bool isCollapse = expandedRank > collapsedRank; - if (!isCollapse) { - std::swap(expandedRank, collapsedRank); - std::swap(expandedType, collapsedType); - } + if (expandedRank < collapsedRank) + return op.emitOpError("expected the type ") + << expandedType + << " to have higher rank than the type = " << collapsedType; if (expandedRank == 0) return op.emitOpError("expected non-zero memref ranks"); if (expandedRank == collapsedRank) @@ -1825,11 +1877,13 @@ if (!isReassociationValid(maps, &invalidIdx)) return op.emitOpError("expected reassociation map #") << invalidIdx << " to be valid and contiguous"; - return verifyReshapeLikeShapes(op, collapsedType, expandedType, !isCollapse); + return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion); } -static LogicalResult verify(ReshapeOp op) { - MemRefType expandedType, collapsedType; +template +static LogicalResult verifyReshapeOp(TensorReshapeOp op, + MemRefType expandedType, + MemRefType collapsedType) { if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) return failure(); auto maps = op.getReassociationMaps(); @@ -1840,9 +1894,24 @@ return success(); } -void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); +static LogicalResult verify(ExpandShapeOp op) { + return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); +} + +void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add, + CollapseMixedReshapeOps>(context); +} + +static LogicalResult 
verify(CollapseShapeOp op) { + return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); +} + +void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add, + CollapseMixedReshapeOps>(context); } //===----------------------------------------------------------------------===// @@ -1877,7 +1946,7 @@ return RankedTensorType::get(newShape, type.getElementType()); } -void mlir::linalg::TensorReshapeOp::build( +void mlir::linalg::TensorCollapseShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { @@ -1886,21 +1955,27 @@ getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); - result.addAttribute(ReshapeOp::getReassociationAttrName(), + result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } -void mlir::linalg::TensorReshapeOp::build( - OpBuilder &b, OperationState &result, Type resultType, Value src, +void mlir::linalg::TensorExpandShapeOp::build( + OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { + auto resultType = computeTensorReshapeCollapsedType( + src.getType().cast(), + getSymbolLessAffineMaps( + convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); - result.addAttribute(ReshapeOp::getReassociationAttrName(), + result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } -static LogicalResult verify(TensorReshapeOp op) { - RankedTensorType expandedType, collapsedType; +template +static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, + RankedTensorType expandedType, + RankedTensorType collapsedType) { if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) return failure(); @@ -1913,9 +1988,18 @@ return success(); } +static LogicalResult verify(TensorExpandShapeOp op) { + return 
verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); +} + +static LogicalResult verify(TensorCollapseShapeOp op) { + return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); +} + namespace { /// Reshape of a splat constant can be replaced with a constant of the result /// type. +template struct FoldReshapeWithConstant : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, @@ -1936,11 +2020,12 @@ /// /// For such op chains, we can create new linalg.fill ops with the result /// type of the linalg.tensor_reshape op. +template struct FoldFillWithTensorReshape : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { - auto oldFill = reshapeOp.src().getDefiningOp(); + auto oldFill = reshapeOp.src().template getDefiningOp(); if (!oldFill) return failure(); @@ -1955,14 +2040,38 @@ }; } // namespace -void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add, FoldFillWithTensorReshape, - FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>( - context); +void TensorExpandShapeOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results + .add, + CollapseMixedReshapeOps, + FoldFillWithTensorReshape, + FoldInitTensorWithTensorReshapeOp, + FoldReshapeWithConstant>(context); +} + +void TensorCollapseShapeOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results + .add, + CollapseMixedReshapeOps, + FoldFillWithTensorReshape, + FoldInitTensorWithTensorReshapeOp, + FoldReshapeWithConstant>(context); +} + +LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim( + OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { + auto resultShape = + getAsValues(b, getLoc(), + getReshapeOutputShapeFromInputShape( + b, getLoc(), src(), 
getResultType().getShape(), + getReassociationMaps())); + reifiedReturnShapes.emplace_back(std::move(resultShape)); + return success(); } -LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim( +LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim( OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { auto resultShape = getAsValues(b, getLoc(), @@ -2753,13 +2862,23 @@ // TODO: Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of // OpInterfaces where a Linalg "named" op **isa** LinalgOp. -OpFoldResult ReshapeOp::fold(ArrayRef operands) { +OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { if (succeeded(foldMemRefCast(*this))) return getResult(); - return foldReshapeOp(*this, operands); + return foldReshapeOp(*this, operands); +} +OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return foldReshapeOp(*this, operands); +} +OpFoldResult TensorExpandShapeOp::fold(ArrayRef operands) { + return foldReshapeOp(*this, + operands); } -OpFoldResult TensorReshapeOp::fold(ArrayRef operands) { - return foldReshapeOp(*this, operands); +OpFoldResult TensorCollapseShapeOp::fold(ArrayRef operands) { + return foldReshapeOp(*this, + operands); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -149,17 +149,25 @@ /// Conversion pattern that replaces `linalg.tensor_reshape` with /// `linalg.reshape`. 
+template class BufferizeTensorReshapeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + using ReshapeOp = typename std::conditional_t< + std::is_same::value, ExpandShapeOp, + CollapseShapeOp>; LogicalResult matchAndRewrite(TensorReshapeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary()); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()).cast(), - adaptor.src(), adaptor.reassociation()); + Adaptor adaptor(operands, op->getAttrDictionary()); + rewriter.replaceOpWithNewOp(op, + this->getTypeConverter() + ->convertType(op.getType()) + .template cast(), + adaptor.src(), + adaptor.reassociation()); return success(); } }; @@ -348,7 +356,8 @@ BufferizeAnyLinalgOp, BufferizeFillOp, BufferizeInitTensorOp, - BufferizeTensorReshapeOp, + BufferizeTensorReshapeOp, + BufferizeTensorReshapeOp, SubTensorOpConverter, SubTensorInsertOpConverter >(typeConverter, patterns.getContext()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -31,7 +31,7 @@ // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to // a tensor instead. - return builder.create( + return builder.create( loc, type, createNewTensorOp, ArrayRef{}); } @@ -159,8 +159,8 @@ /// Canonicalizes the pattern of the form /// /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> -/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into -/// tensor +/// %reshaped_tensor = linalg.tensor_collapse_shape %tensor [] +/// : tensor<1xi32> into tensor /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor /// /// to just %element. 
@@ -170,10 +170,11 @@ LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { - if (extract.indices().size() != 0) + if (!extract.indices().empty()) return failure(); - auto tensorReshape = extract.tensor().getDefiningOp(); + auto tensorReshape = + extract.tensor().getDefiningOp(); if (tensorReshape == nullptr) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -362,10 +362,11 @@ for (auto operand : llvm::enumerate(values)) { if (operand.value().getType() == newInputOutputTypes[flattenedIdx]) res.push_back(operand.value()); - else - res.push_back(rewriter.create( + else { + res.push_back(rewriter.create( loc, newInputOutputTypes[flattenedIdx], operand.value(), convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx]))); + } ++flattenedIdx; } return res; @@ -395,11 +396,11 @@ RankedTensorType origResultType = genericOp.getResult(result.index()) .getType() .template cast(); - if (origResultType != result.value().getType()) - resultReplacements.push_back(rewriter.create( + if (origResultType != result.value().getType()) { + resultReplacements.push_back(rewriter.create( loc, origResultType, result.value(), convertAffineMapArrayToExprs(reassociationMaps[index]))); - else + } else resultReplacements.push_back(result.value()); } rewriter.replaceOp(genericOp, resultReplacements); @@ -460,8 +461,8 @@ Location loc = subTensorOp.getLoc(); Value newSubTensor = rewriter.create( loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides); - rewriter.replaceOpWithNewOp(subTensorOp, resultType, - newSubTensor, *reassociation); + rewriter.replaceOpWithNewOp( + subTensorOp, resultType, newSubTensor, *reassociation); return success(); } }; @@ -482,7 +483,7 @@ reassociation->size() == static_cast(sourceType.getRank())) return 
failure(); Location loc = insertOp.getLoc(); - auto reshapedSource = rewriter.create( + auto reshapedSource = rewriter.create( loc, insertOp.source(), *reassociation); rewriter.replaceOpWithNewOp( insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(), @@ -500,7 +501,8 @@ patterns.add( context); - TensorReshapeOp::getCanonicalizationPatterns(patterns, context); + TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context); + TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -306,8 +306,7 @@ /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` /// /// and reshape: -/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, -/// affine_map<(i, j, k, l) -> (j, k, l)>] : +/// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] : /// tensor into tensor /// /// would be rewritten into: @@ -348,24 +347,21 @@ resultExprs, context); } -/// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is -/// true) or its producer (if `asProducer` is false) given the indexing map at -/// its use. -static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp, - AffineMap useIndexMap, - bool asProducer) { - RankedTensorType returnType = reshapeOp.getResultType(); - RankedTensorType operandType = reshapeOp.getSrcType(); - // Reshape is fusable with its consumer (i.e. reshape as a producer) when its - // operand is of lesser rank than the result. Fusing when operand has higher - // rank will require use of mods and divs in the indexing maps of the fused op - // which would make it non-invertible. Similarly reshape is fused with its - // producer (i.e. 
reshape as consumer) only if the return type has lesser - // rank. - if ((asProducer && reshapeOp.getSrcType().hasStaticShape() && - returnType.getRank() < operandType.getRank()) || - (!asProducer && reshapeOp.getResultType().hasStaticShape() && - operandType.getRank() < returnType.getRank())) +// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a +// producer). Fusing when operand has higher rank will require use of mods and +// divs in the indexing maps of the fused op which would make it non-invertible. +static bool isTensorReshapeOpFoldableByLinearization( + TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { + if (!asProducer && expandOp.getResultType().hasStaticShape()) + return false; + return useIndexMap.isPermutation(); +} + +// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a +// consumer). +static bool isTensorReshapeOpFoldableByLinearization( + TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) { + if (asProducer && collapseOp.getSrcType().hasStaticShape()) return false; return useIndexMap.isPermutation(); } @@ -398,17 +394,14 @@ /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, /// affine_map<(d0, d1, d2) -> (d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] -/// %d = linalg.tensor_reshape %c -/// [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, -/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>, -/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] +/// %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] /// : tensor into tensor /// /// The reshape can be folded into the `genericOp` if its loop dimensionality -/// is increased to match the result (operand) of the tensor_reshape when the -/// reshape is expanding (folding). The indexing_map of the fused tensor in the -/// `genericOp` and the reassociation map helps compute the indexing maps of -/// the modified op. 
For the above example, based on the reassociation map it +/// is increased to match the result (operand) of the tensor_expand_shape. +/// The indexing_map of the fused tensor in the `genericOp` and the +/// reassociation map helps compute the indexing maps of the modified op. +/// For the above example, based on the reassociation map it /// can be concluded that /// /// - The loop used to access the first dimension of the fused tensor is split @@ -436,14 +429,9 @@ /// Since operands to the linalg generic are now 5D, reshapes can be introduced /// to make it consistent /// -/// %0 = linalg.tensor_reshape %a -/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2), -/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4), -/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)] +/// %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]] /// : tensor into tensor -/// %1 = linalg.tensor_reshape %b -/// [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2), -/// affine_map<(e0, e1, e2, e3) -> (e3)] +/// %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]] /// : tensor into tensor /// /// The added reshapes are again expanding patterns, so they will get fused @@ -607,11 +595,12 @@ return RankedTensorType::get(expandedShape, originalType.getElementType()); } -/// Returns the reassociation maps to use in the `linalg.tensor_reshape` -/// operation to convert the operands of the origial operation to operands of +/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape` +/// operation to convert the operands of the original operation to operands of /// the expanded operation. The same method is used to compute the -/// `linalg.tensor_reshape` used to collapse the result of the expanded op to -/// get the value that can replace all uses of the results of the original op. +/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded +/// op to get the value that can replace all uses of the results of the original +/// op. 
static SmallVector getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo) { @@ -671,25 +660,29 @@ } } -/// Implements the fusion of a tensor_reshape op and a generic op as explained -/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have -/// been satisfied. +/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op +/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes +/// that those conditions have been satisfied. static Optional> -fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp, +fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, unsigned fusedTensorIndex, PatternRewriter &rewriter) { assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) && "preconditions for fuse operation failed"); // Check if reshape is expanding or collapsing. - bool isExpanding = - reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank(); - RankedTensorType expandedType = - isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType(); + auto expandingReshapeOp = dyn_cast(*reshapeOp); + auto collapsingReshapeOp = dyn_cast(*reshapeOp); + bool isExpanding = (expandingReshapeOp != nullptr); + RankedTensorType expandedType = isExpanding + ? expandingReshapeOp.getResultType() + : collapsingReshapeOp.getSrcType(); ExpansionInfo expansionInfo; - if (failed(expansionInfo.compute(genericOp, fusedTensorIndex, - reshapeOp.getReassociationMaps(), - expandedType.getShape(), rewriter))) + if (failed(expansionInfo.compute( + genericOp, fusedTensorIndex, + isExpanding ? 
expandingReshapeOp.getReassociationMaps() + : collapsingReshapeOp.getReassociationMaps(), + expandedType.getShape(), rewriter))) return llvm::None; if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) @@ -703,7 +696,8 @@ SmallVector expandedOpOperands; for (auto operand : llvm::enumerate(genericOp.getInputs())) { if (operand.index() == fusedTensorIndex) { - expandedOpOperands.push_back(reshapeOp.src()); + expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() + : collapsingReshapeOp.src()); continue; } AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index()); @@ -714,7 +708,7 @@ // Reshape the operand to get the right type. SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); - expandedOpOperands.push_back(rewriter.create( + expandedOpOperands.push_back(rewriter.create( genericOp.getLoc(), expandedOperandType, operand.value(), reassociation)); continue; @@ -732,7 +726,7 @@ if (expandedOutputType != result.value().getType()) { SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); - outputs.push_back(rewriter.create( + outputs.push_back(rewriter.create( genericOp.getLoc(), expandedOutputType, result.value(), reassociation)); } @@ -763,7 +757,7 @@ SmallVector reassociation = getReassociationForExpansion( genericOp.getOutputIndexingMap(result.index()), expansionInfo); - resultVals.push_back(rewriter.create( + resultVals.push_back(rewriter.create( genericOp.getLoc(), result.value().getType(), fusedOp->getResult(result.index()), reassociation)); } else { @@ -776,18 +770,15 @@ namespace { -/// Pattern to fold tensor_reshape op with its consumer by using the source of -/// the reshape op as the operand in the consumer (instead of the result of the -/// tensor_reshapeop) when the tensor_reshape op is collapsing. The -/// corresponding index map in the consumer needs to be modified to linearize -/// the folded dimension. 
+/// Pattern to fold a tensor_expand_shape or tensor_collapse_shape op with its +/// consumer by using the source of the reshape op as the operand in the +/// consumer (instead of the result of the reshape op). The corresponding index +/// map in the consumer needs to be modified to linearize the folded dimension. /// /// For example, /// /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -/// %0 = linalg.tensor_reshape %arg0 -/// [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>, -/// affine_map<(i, j, k, l) -> (l)>] +/// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] /// tensor into tensor /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } /// ins(%0, %arg1 : tensor, tensor) ... @@ -800,7 +791,7 @@ /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } /// ins(%arg0, %arg1 : tensor, tensor) ... /// -> tensor -template +template struct FoldProducerReshapeOpByLinearization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -810,20 +801,25 @@ if (!genericOp.hasTensorSemantics()) return failure(); for (auto operand : llvm::enumerate(genericOp.getInputs())) { - TensorReshapeOp reshapeOp = - operand.value().getDefiningOp(); - if (!reshapeOp || - !isTensorReshapeOpFoldableByLinearization( + auto reshapeOp = operand.value().getDefiningOp(); + if (!reshapeOp) + continue; + + Value src = reshapeOp.src(); + RankedTensorType operandType = reshapeOp.getSrcType(); + RankedTensorType returnType = reshapeOp.getResultType(); + + if (!isTensorReshapeOpFoldableByLinearization( reshapeOp, genericOp.getInputIndexingMap(operand.index()), /*asProducer =*/true) || (foldUnitDimReshapesOnly && - !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(), + !isUnitDimExpansionOnly(returnType.getShape(), reshapeOp.getReassociationMaps()))) continue; // Compute the fused operands list, SmallVector fusedOperands(genericOp.getInputs()); - fusedOperands[operand.index()] = reshapeOp.src(); +
fusedOperands[operand.index()] = src; fusedOperands.append(genericOp.getOutputs().begin(), genericOp.getOutputs().end()); @@ -836,9 +832,8 @@ auto invMap = inversePermutation(fusedIndexMaps[operand.index()]); // Compute the indexing map to use for the result of the producer. - AffineMap modifiedMap = - linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(), - reshapeOp.getReassociationMaps()); + AffineMap modifiedMap = linearizeCollapsedDims( + invMap, returnType.getShape(), reshapeOp.getReassociationMaps()); for (AffineExpr expr : modifiedMap.getResults()) { if (!expr.isPureAffine()) return failure(); @@ -884,8 +879,7 @@ /// /// For example, /// -/// %0 = linalg.tensor_reshape %A [ -/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] +/// %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] /// : tensor<12544x16xf32> into tensor<112x112x16xf32> /// %2 = linalg.generic {indexing_maps = [ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, @@ -903,8 +897,7 @@ /// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 /// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) { /// } -> tensor<12544x16xf32> -/// %3 = linalg.tensor_reshape %2 [ -/// #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] +/// %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]] /// : tensor<12544x16xf32> into tensor<112x112x16xf32> struct PushExpandingReshape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -923,16 +916,14 @@ int64_t destRank = genericOp.getNumParallelLoops(); SmallVector newOperands = llvm::to_vector<4>(genericOp.getInputs()); - TensorReshapeOp reshapeFound; - // 1. Look for tensor_reshape operands and figure out save the dimensions - // merged. + TensorExpandShapeOp reshapeFound; + // 1. Look for tensor_expand_shape operands and record the + // dimensions they merge.
for (auto operand : llvm::enumerate(genericOp.getInputs())) { - TensorReshapeOp reshapeOp = - operand.value().template getDefiningOp(); - if (!reshapeOp || reshapeOp.getSrcType().getRank() > - reshapeOp.getResultType().getRank()) { + auto reshapeOp = + operand.value().template getDefiningOp(); + if (!reshapeOp) continue; - } // TODO: We could support non-identity map as long as the merged // dimensions are still contiguous. if (!genericOp.getIndexingMaps()[operand.index()].isIdentity()) @@ -997,7 +988,7 @@ auto newOutputType = RankedTensorType::get( reshapeFound.getSrcType().getShape(), output.getType().template cast().getElementType()); - Value newOutput = rewriter.create( + Value newOutput = rewriter.create( genericOp->getLoc(), newOutputType, output, reassociation); newOutputTypes.push_back(newOutputType); newOutputs.push_back(newOutput); @@ -1013,7 +1004,7 @@ // 6. Reshape the so that the type matches the uses. SmallVector newResults; for (auto result : llvm::enumerate(newOp->getResults())) { - newResults.push_back(rewriter.create( + newResults.push_back(rewriter.create( genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()], result.value(), reassociation)); } @@ -1022,9 +1013,9 @@ } }; -/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the -/// reshape op is collapsing dimensions. The dimensionality of the loop in the -/// consumer is expanded. +/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op, +/// when the reshape op is collapsing dimensions. The dimensionality of the loop +/// in the consumer is expanded. 
class FoldWithProducerReshapeOpByExpansion : public OpRewritePattern { public: @@ -1037,16 +1028,14 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { for (auto operand : llvm::enumerate(genericOp.getInputs())) { - TensorReshapeOp reshapeOp = - operand.value().getDefiningOp(); + TensorCollapseShapeOp reshapeOp = + operand.value().getDefiningOp(); if (!reshapeOp) continue; // Fold only if // - The tensor reshape op is folding. // - All constraints of fusing with reshape by expansion are met. - if (reshapeOp.getSrcType().getRank() < - reshapeOp.getResultType().getRank() || - !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) || + if (!isFusableWithReshapeByDimExpansion(genericOp, operand.index()) || (!controlFoldingReshapes( reshapeOp->getResult(0), genericOp.getInputOpOperands()[operand.index()]))) @@ -1067,16 +1056,17 @@ ControlElementwiseOpsFusionFn controlFoldingReshapes; }; -/// Pattern to fold tensor_reshape op with its producer. The corresponding index -/// map in the consumer needs to be modified to linearize the folded dimension. -template +/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its +/// producer. The corresponding index map in the consumer needs to be modified +/// to linearize the folded dimension. 
+template struct FoldConsumerReshapeOpByLinearization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { - GenericOp producer = reshapeOp.src().getDefiningOp(); + GenericOp producer = reshapeOp.src().template getDefiningOp(); if (!producer || !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 || !isTensorReshapeOpFoldableByLinearization( @@ -1131,19 +1121,14 @@ } }; -/// Pattern to fold a tensor_reshape op with its producer generic op if the -/// tensor_reshape op is expanding, by expanding the dimensionality of the loop -/// in the producer op. +/// Pattern to fold a tensor_expand_shape op with its producer generic op +/// by expanding the dimensionality of the loop in the producer op. struct FoldReshapeWithGenericOpByExpansion - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp, PatternRewriter &rewriter) const override { - // Fold only if - // - The tensor reshape op is a expanding case. - // - All constraints of fusing with reshape by expansion are met. - if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank()) - return failure(); + // Fold only if all constraints of fusing with reshape by expansion are met. 
GenericOp producer = reshapeOp.src().getDefiningOp(); if (!producer || producer.getNumOutputs() != 1 || !isFusableWithReshapeByDimExpansion(producer, @@ -1245,9 +1230,14 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer) { - auto reshapeOp = producer.getDefiningOp(); - return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), - reshapeOp.getReassociationMaps()); + auto expandShapeOp = producer.getDefiningOp(); + if (expandShapeOp) + return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(), + expandShapeOp.getReassociationMaps()); + auto collapseShapeOp = + producer.getDefiningOp(); + return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(), + collapseShapeOp.getReassociationMaps()); } namespace { @@ -1360,16 +1350,22 @@ void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { - patterns.add, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); + patterns + .add, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( + patterns.getContext()); } void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { - patterns.add, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); + patterns + .add, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( + patterns.getContext()); } void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( @@ -1391,7 +1387,8 @@ AffineApplyOp::getCanonicalizationPatterns(patterns, context); GenericOp::getCanonicalizationPatterns(patterns, context); IndexedGenericOp::getCanonicalizationPatterns(patterns, context); - TensorReshapeOp::getCanonicalizationPatterns(patterns, context); + TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context); + TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context); } 
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -61,7 +61,7 @@ // CHECK-LABEL: @test_broadcast func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 @@ -79,7 +79,7 @@ // CHECK-LABEL: @test_broadcast_swapped_args func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> { // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1 + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg1 // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 @@ -98,8 +98,8 @@ // CHECK-LABEL: @test_multibroadcast func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> { // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32> - // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]] - // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape %arg1 {{\[}}[0, 1]] + // CHECK: 
[[RESHAPE1:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_collapse_shape %arg1 {{\[}}[0, 1]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 @@ -467,7 +467,7 @@ // CHECK-LABEL: @test_reshape_downrank func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32> // CHECK: return [[RESHAPE]] return %0 : tensor<6xf32> @@ -477,7 +477,7 @@ // CHECK-LABEL: @test_reshape_uprank func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32> // CHECK: return [[RESHAPE]] return %0 : tensor<2x3xf32> @@ -488,8 +488,8 @@ // CHECK-LABEL: @test_reshape_samerank func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>) - // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_reshape %[[RESHAPE1]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32> // CHECK-NEXT: return %[[RESHAPE2]] return %0 : 
tensor<2x3xf32> @@ -499,7 +499,7 @@ // CHECK-LABEL: @test_reshape_downrank_6D func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { - // CHECK: linalg.tensor_reshape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] + // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> return %0 : tensor<6x5x77xf32> } @@ -549,7 +549,7 @@ // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> + // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] @@ -559,7 +559,7 @@ // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32> // CHECK: constant 1.0 @@ -600,7 +600,7 @@ // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> + // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] @@ -610,7 +610,7 @@ // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 - // 
CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> + // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: constant 1 @@ -650,7 +650,7 @@ // CHECK: ^bb0(%arg1: i1, %arg2: i1) // CHECK: [[RES:%.+]] = and %arg1, %arg2 : i1 // CHECK: linalg.yield [[RES]] : i1 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> + // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: constant false @@ -822,19 +822,19 @@ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1, 2], [3]] + // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]] %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = 
["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>) return @@ -1097,7 +1097,7 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) { // Initial piece computes the sum of the pooling region, with appropriate padding. // CHECK: [[CONST:%.+]] = constant 0 - // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] + // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: [[CONST:%.+]] = constant 0 // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]]) @@ -1188,9 +1188,9 @@ // CHECK: ^bb0(%arg3: f32, %arg4: f32): // no predecessors // CHECK: linalg.yield %arg3 : f32 // CHECK: } -> tensor<1x5x5x33xf32> - // CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>) - // CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]] + // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]] %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) return } @@ -1202,10 +1202,10 @@ func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () { // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] // CHECK: 
%[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX1:.+]] = linalg.index 1 - // CHECK: %[[IDX2:.+]] = linalg.index 2 - // CHECK: %[[IDX3:.+]] = linalg.index 3 + // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[IDX1:.+]] = linalg.index 1 + // CHECK: %[[IDX2:.+]] = linalg.index 2 + // CHECK: %[[IDX3:.+]] = linalg.index 3 // CHECK-DAG: %[[XYMIN:.+]] = constant 0 // CHECK-DAG: %[[YMAX:.+]] = constant 1 // CHECK-DAG: %[[XMAX:.+]] = constant 1 @@ -1271,9 +1271,9 @@ func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () { // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX1:.+]] = linalg.index 1 - // CHECK: %[[IDX2:.+]] = linalg.index 2 + // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[IDX1:.+]] = linalg.index 1 + // CHECK: %[[IDX2:.+]] = linalg.index 2 // CHECK: %[[IDX3:.+]] = linalg.index 3 // CHECK: %[[XYMIN:.+]] = constant 0 // CHECK: %[[YMAX:.+]] = constant 1 @@ -1353,9 +1353,9 @@ func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () { // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX1:.+]] = linalg.index 1 - // CHECK: %[[IDX2:.+]] = linalg.index 2 + // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[IDX1:.+]] = linalg.index 1 + // CHECK: %[[IDX2:.+]] = linalg.index 2 // CHECK: %[[IDX3:.+]] = linalg.index 3 // CHECK-DAG: %[[XYMIN:.+]] = constant 0 // CHECK-DAG: %[[YMAX:.+]] = constant 1 @@ -1422,7 +1422,7 @@ // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX3:.+]] = linalg.index 3 // CHECK: %[[XYMIN:.+]] = constant 0 diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- 
a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -253,15 +253,15 @@ // ----- -// CHECK-LABEL: func @bufferize_tensor_reshape( +// CHECK-LABEL: func @bufferize_tensor_collapse_shape( // CHECK-SAME: %[[IN:.*]]: tensor<4x5xf32> -func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> { - %out = linalg.tensor_reshape %arg0 [[0, 1]] : +func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> { + %out = linalg.tensor_collapse_shape %arg0 [[0, 1]] : tensor<4x5xf32> into tensor<20xf32> return %out : tensor<20xf32> } // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32> -// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]] +// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]] // CHECK-SAME: : memref<4x5xf32> into memref<20xf32> // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32> // CHECK: return %[[TENSOR]] diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -46,9 +46,9 @@ // CHECK-LABEL: zero_rank_reshape_multi func @zero_rank_reshape_multi(%arg0: tensor) -> tensor { // CHECK: return %arg0 - %0 = linalg.tensor_reshape %arg0 [] : tensor into tensor<1xf32> - %1 = linalg.tensor_reshape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> - %2 = linalg.tensor_reshape %1 [] : tensor<1x1xf32> into tensor + %0 = linalg.tensor_expand_shape %arg0 [] : tensor into tensor<1xf32> + %1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> + %2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor return %2 : tensor } @@ -56,175 +56,175 @@ func @collapsing_tensor_reshapes(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor into tensor - %1 = 
linalg.tensor_reshape %0 [[0, 1], [2]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] : tensor into tensor return %1 : tensor } // CHECK-LABEL: collapsing_tensor_reshapes -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: linalg.tensor_collapse_shape // ----- func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x1xf32> into tensor<1xf32> - %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor + %1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor return %1 : tensor } // CHECK-LABEL: collapsing_tensor_reshapes_to_zero -// CHECK: linalg.tensor_reshape %{{.*}} [] +// CHECK: linalg.tensor_collapse_shape %{{.*}} [] // CHECK-SAME: tensor<1x1x1xf32> into tensor // ----- func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>) -> memref { - %0 = linalg.reshape %arg0 [[0, 1, 2]] + %0 = linalg.collapse_shape %arg0 [[0, 1, 2]] : memref<1x1x1xf32> into memref<1xf32> - %1 = linalg.reshape %0 [] : memref<1xf32> into memref + %1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref return %1 : memref } // CHECK-LABEL: collapsing_memref_reshapes_to_zero -// CHECK: linalg.reshape %{{.*}} [] +// CHECK: linalg.collapse_shape %{{.*}} [] // CHECK-SAME: memref<1x1x1xf32> into memref // ----- func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] : tensor into tensor - %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4]] + %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4]] : tensor into tensor return %1 : tensor } // CHECK-LABEL: expanding_tensor_reshapes -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: 
linalg.tensor_reshape +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: linalg.tensor_expand_shape // ----- func @collapsing_memref_reshapes(%arg0 : memref) -> memref { - %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] + %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : memref into memref - %1 = linalg.reshape %0 [[0, 1], [2]] + %1 = linalg.collapse_shape %0 [[0, 1], [2]] : memref into memref return %1 : memref } // CHECK-LABEL: collapsing_memref_reshapes -// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.reshape +// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: linalg.collapse_shape // ----- func @expanding_memref_reshapes(%arg0 : memref) -> memref { - %0 = linalg.reshape %arg0 [[0, 1], [2]] + %0 = linalg.expand_shape %arg0 [[0, 1], [2]] : memref into memref - %1 = linalg.reshape %0 [[0, 1], [2], [3, 4]] + %1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]] : memref into memref return %1 : memref } // CHECK-LABEL: expanding_memref_reshapes -// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.reshape +// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: linalg.expand_shape // ----- func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor) -> tensor<1x1x1xf32> { - %0 = linalg.tensor_reshape %arg0 [] : tensor into tensor<1xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [] : tensor into tensor<1xf32> + %1 = linalg.tensor_expand_shape %0 [[0, 1, 2]] : tensor<1xf32> into tensor<1x1x1xf32> return %1 : tensor<1x1x1xf32> } // CHECK-LABEL: expanding_tensor_reshapes_to_zero -// CHECK: linalg.tensor_reshape %{{.*}} [] +// CHECK: linalg.tensor_expand_shape %{{.*}} [] // CHECK-SAME: tensor into tensor<1x1x1xf32> // ----- func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref) -> memref<1x1x1xf32> { - %0 = linalg.reshape %arg0 [] : memref into memref<1xf32> - %1 = linalg.reshape 
%0 [[0, 1, 2]] + %0 = linalg.expand_shape %arg0 [] : memref into memref<1xf32> + %1 = linalg.expand_shape %0 [[0, 1, 2]] : memref<1xf32> into memref<1x1x1xf32> return %1 : memref<1x1x1xf32> } // CHECK-LABEL: expanding_memref_reshapes_to_zero -// CHECK: linalg.reshape %{{.*}} [] +// CHECK: linalg.expand_shape %{{.*}} [] // CHECK-SAME: memref into memref<1x1x1xf32> // ----- func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] : tensor<12x4xf32> into tensor<3x4x4xf32> - %1 = linalg.tensor_reshape %0 [[0, 1], [2]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] : tensor<3x4x4xf32> into tensor<12x4xf32> return %1 : tensor<12x4xf32> } // CHECK-LABEL: @fold_tensor_reshape -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.{{.*}}shape // ----- func @fold_tensor_reshape_dynamic(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] : tensor into tensor - %1 = linalg.tensor_reshape %0 [[0, 1], [2]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] : tensor into tensor return %1 : tensor } // CHECK-LABEL: @fold_tensor_reshape_dynamic -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.{{.*}}_shape // ----- func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> { - %0 = linalg.reshape %arg0 [[0, 1], [2]] + %0 = linalg.expand_shape %arg0 [[0, 1], [2]] : memref<12x4xf32> into memref<3x4x4xf32> - %1 = linalg.reshape %0 [[0, 1], [2]] + %1 = linalg.collapse_shape %0 [[0, 1], [2]] : memref<3x4x4xf32> into memref<12x4xf32> return %1 : memref<12x4xf32> } // CHECK-LABEL: @fold_memref_reshape -// CHECK-NOT: linalg.reshape +// CHECK-NOT: linalg.{{.*}}_shape // ----- func @fold_memref_reshape_dynamic(%arg0 : memref) -> memref { - %0 = linalg.reshape %arg0 [[0, 1], [2]] + %0 = linalg.expand_shape %arg0 [[0, 1], [2]] : memref into memref - %1 = 
linalg.reshape %0 [[0, 1], [2]] + %1 = linalg.collapse_shape %0 [[0, 1], [2]] : memref into memref return %1 : memref } // CHECK-LABEL: @fold_memref_reshape_dynamic -// CHECK-NOT: linalg.reshape +// CHECK-NOT: linalg.{{.*}}_shape // ----- func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4, 5, 6]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]] : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]] + %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]] : tensor<40320xf32> into tensor<24x5x42x8xf32> return %1 : tensor<24x5x42x8xf32> } // CHECK: func @reshape_collapse // CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32> -// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] // CHECK: return %[[RESULT]] @@ -232,15 +232,15 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<24x5x42x8xf32> into tensor<40320xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3, 4, 5, 6]] + %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32> return %1 : tensor<2x3x4x5x6x7x8xf32> } // CHECK: func @reshape_expand // CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32> -// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] // CHECK: return %[[RESULT]] @@ -248,84 +248,84 @@ func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]] : tensor<2048xf32> into tensor<1x4x1x512xf32> - %1 = linalg.tensor_reshape %0 [[0, 
1, 2], [3]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]] : tensor<1x4x1x512xf32> into tensor<4x512xf32> return %1 : tensor<4x512xf32> } // CHECK: func @expand_reshape_1D -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] // CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32> // ----- func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]] : tensor<4x512xf32> into tensor<1x4x1x512xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]] : tensor<1x4x1x512xf32> into tensor<2048xf32> return %1 : tensor<2048xf32> } // CHECK: func @fold_reshape_1D -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]] +// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]] // CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32> // ----- func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3], [4], [5]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3], [4], [5]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]] : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32> return %1 : tensor<4x512x1x1xf32> } // CHECK: func @fold_reshape_unit_dims -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] // CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32> // ----- func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] : tensor<2048x1x2048xf32> into 
tensor<1x4x1x512x1x1x512x1x4xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> return %1 : tensor<4x512x1x512x4xf32> } // CHECK: func @expand_reshape_unit_dims -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] // CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> // ----- func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]] : tensor<2xf32> into tensor<2x1x1xf32> - %1 = linalg.tensor_reshape %0 [[0], [1, 2]] + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]] : tensor<2x1x1xf32> into tensor<2x1xf32> return %1 : tensor<2x1xf32> } // CHECK: func @fold_reshape_trailing_unit_dims -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> // ----- func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]] + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]] : tensor into tensor - %1 = linalg.tensor_reshape %0 [[0], [1], [2, 3, 4], [5]] + %1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]] : tensor into tensor return %1 : tensor } // CHECK: func @collapse_reshape_unit_dims_dynamic -// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_collapse_shape // CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8] // CHECK-SAME: tensor into tensor @@ -333,72 +333,72 @@ func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]] : tensor<2xf32> 
into tensor<2x1x1xf32> - %1 = linalg.tensor_reshape %0 [[0], [1, 2]] + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]] : tensor<2x1x1xf32> into tensor<2x1xf32> return %1 : tensor<2x1xf32> } // CHECK: func @fold_reshape_trailing_unit_dims -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> // ----- func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3], [4], [5]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]] : tensor<1x1x?x1x1x1xf32> into tensor - %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]] : tensor into tensor return %1 : tensor } // CHECK: func @fold_reshape_trailing_unit_dims_dynamic -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]] +// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]] // CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor // ----- func @no_fold_reshapes(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]] : tensor into tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } // CHECK-LABEL: func @no_fold_reshapes -// CHECK: linalg.tensor_reshape -// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_expand_shape +// CHECK: linalg.tensor_collapse_shape // ----- func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2, 3], [4]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]] : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32> - %1 = linalg.tensor_reshape %0 [[0], [1, 2], [3, 4]] + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]] : tensor<2x2x3x2x8xf32> into 
tensor<2x6x16xf32> return %1 : tensor<2x6x16xf32> } // CHECK-LABEL: func @no_fold_reshape_incompatible -// CHECK: linalg.tensor_reshape -// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_expand_shape +// CHECK: linalg.tensor_collapse_shape // ----- func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> { - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32> - %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]] : tensor<3x2x2x1xf32> into tensor<12x1xf32> return %1 : tensor<12x1xf32> } // CHECK: func @no_fold_reshape_empty_expr // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32> -// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0], [1], [2, 3] -// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0]] +// CHECK: %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]] // CHECK-SAME: [0, 1, 2], [3] // CHECK: return %[[RES:.+]] : tensor<12x1xf32> @@ -436,49 +436,49 @@ func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { %c0 = constant dense<42> : tensor<2x8xi32> - %0 = linalg.tensor_reshape %c0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] : tensor<2x8xi32> into tensor<2x4x2xi32> return %0 : tensor<2x4x2xi32> } // CHECK-LABEL: @reshape_splat_constant_int32 // CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32> -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: return %[[CST]] func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> { %c0 = constant dense<42> : tensor<2x8xi16> - %0 = linalg.tensor_reshape %c0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] : tensor<2x8xi16> into tensor<2x4x2xi16> return %0 : tensor<2x4x2xi16> } // CHECK-LABEL: @reshape_splat_constant_int16 // CHECK: %[[CST:.*]] = constant 
dense<{{.*}}> : tensor<2x4x2xi16> -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: return %[[CST]] func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> { %c0 = constant dense<42.0> : tensor<2x8xf32> - %0 = linalg.tensor_reshape %c0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] : tensor<2x8xf32> into tensor<2x4x2xf32> return %0 : tensor<2x4x2xf32> } // CHECK-LABEL: @reshape_splat_constant_float32 // CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32> -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: return %[[CST]] func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { %c0 = constant dense<42.0> : tensor<2x8xf64> - %0 = linalg.tensor_reshape %c0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] : tensor<2x8xf64> into tensor<2x4x2xf64> return %0 : tensor<2x4x2xf64> } // CHECK-LABEL: @reshape_splat_constant_float64 // CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64> -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: return %[[CST]] // ----- @@ -733,7 +733,7 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32> - %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]] + %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> return %1 : tensor<2x3x5x4x?x7xf32> } @@ -748,7 +748,7 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32> - %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> return %1 : tensor<6x5x?xf32> } @@ -898,7 +898,7 @@ %c1 = constant 1 : index %c3 = constant 3 : index %c4 = constant 4 : index - %0 = 
linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32> %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32> @@ -921,7 +921,7 @@ { %c1 = constant 1 : index %c2 = constant 2 : index - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> %1 = memref.dim %0, %c1 : tensor<6x5x?xf32> %2 = memref.dim %0, %c2 : tensor<6x5x?xf32> @@ -979,7 +979,7 @@ %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32> // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<6x4xf32>, f32 -> tensor<6x4xf32> %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32> - %reshape = linalg.tensor_reshape %fill [[0, 1, 2], [3]] + %reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> // CHECK: return %[[FILL]] : tensor<6x4xf32> return %reshape : tensor<6x4xf32> @@ -991,10 +991,10 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor func @fold_fill_reshape_dynamic(%arg0 : tensor) -> tensor { %zero = constant 0.0 : f32 - // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] + // CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] %0 = linalg.fill(%arg0, %zero) : tensor, f32 -> tensor // CHECK: %[[RESULT:.+]] = linalg.fill(%[[RESHAPE]], %{{.+}}) - %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4]] + %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]] : tensor into tensor // CHECK: return %[[RESULT]] return %1 : tensor diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -19,7 +19,7 @@ // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] // CHECK: %[[detensored_res:.*]] 
= addf %[[arg1_val]], %[[arg2_val]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -61,7 +61,7 @@ // CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_multiple_ops(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -83,7 +83,7 @@ // CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] // CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_foreign_op(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -103,5 +103,5 @@ // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] // CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]]) // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] // CHECK: return 
%[[reshaped_tensor_res]] diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir --- a/mlir/test/Dialect/Linalg/detensorize_if.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -10,10 +10,10 @@ func @main() -> (tensor) attributes {} { %c0 = constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -55,7 +55,7 @@ // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } @@ -74,10 +74,10 @@ func @main() -> (tensor) attributes {} { %c0 = constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -124,7 +124,7 @@ // CHECK-NEXT: br ^[[bb4:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb4]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: 
linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } @@ -140,10 +140,10 @@ func @main() -> (tensor) attributes {} { %c0 = constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -164,7 +164,7 @@ ^bb2(%6: tensor): // pred: ^bb1 %12 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor + %reshaped12 = linalg.tensor_collapse_shape %12 [] : tensor<1xi32> into tensor %7 = linalg.init_tensor [] : tensor %8 = linalg.generic #attrs ins(%6, %reshaped12 : tensor, tensor) @@ -191,6 +191,6 @@ // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir --- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir @@ -12,7 +12,7 @@ func @main(%farg0 : tensor) -> (tensor) attributes {} { %c10 = constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor %3 = linalg.init_tensor [] : tensor %4 = linalg.generic #attrs 
ins(%farg0, %reshaped1 : tensor, tensor) @@ -30,7 +30,7 @@ // DET-ALL-NEXT: tensor.extract %{{.*}}[] // DET-ALL-NEXT: cmpi slt, %{{.*}}, %{{.*}} // DET-ALL-NEXT: tensor.from_elements %{{.*}} -// DET-ALL-NEXT: linalg.tensor_reshape %{{.*}} +// DET-ALL-NEXT: linalg.tensor_collapse_shape %{{.*}} // DET-ALL-NEXT: return %{{.*}} : tensor // DET-ALL-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -52,7 +52,7 @@ // DET-ALL: br ^[[bb1]](%{{.*}} : i32) // DET-ALL: ^[[bb3]](%{{.*}}: i32) // DET-ALL: tensor.from_elements {{.*}} -// DET-ALL: linalg.tensor_reshape {{.*}} +// DET-ALL: linalg.tensor_collapse_shape {{.*}} // DET-ALL: return %{{.*}} : tensor // Test detensoring only ops involed in control-flow. @@ -69,5 +69,5 @@ // DET-CF: br ^[[bb1]](%{{.*}} : i32) // DET-CF: ^[[bb3]](%{{.*}}: i32) // DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-CF: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// DET-CF: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-CF: return %{{.*}} : tensor diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir @@ -80,7 +80,7 @@ // DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) // DET-ALL: ^[[bb2]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-ALL: linalg.init_tensor [10] : tensor<10xi32> // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) { // DET-ALL: ^bb0(%{{.*}}: i32, 
%{{.*}}: i32): @@ -89,11 +89,11 @@ // DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>) // DET-ALL: ^[[bb3]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-ALL: return %{{.*}} : tensor // DET-ALL: } -// Try to detensor pure control-flow. However, that fails since the potential +// Try to detensor pure control-flow. However, that fails since the potential // detensorable component contains some ops that cannot be detensored. // // DET-CF-LABEL: func @main diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir @@ -10,10 +10,10 @@ func @main() -> () attributes {} { %c0 = constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -23,11 +23,11 @@ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @drop_one_trip_loops -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.tensor_collapse_shape %{{.*}} 
{{\[}}[0, 1], [2]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] // ----- @@ -101,7 +101,7 @@ } // CHECK: #[[$MAP0:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @drop_all_loops -// CHECK: linalg.tensor_reshape %{{.*}} [] +// CHECK: linalg.tensor_collapse_shape %{{.*}} [] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK-SAME: iterator_types = [] @@ -162,7 +162,7 @@ // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @leading_dim_1_canonicalization -// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]] +// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]] // CHECK-SAME: iterator_types = ["parallel"] @@ -183,8 +183,8 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32> - %1 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32> + %1 = linalg.tensor_expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32> %2 = linalg.generic #trait ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) outs(%shape : tensor<5x5xf32>) { @@ -198,11 +198,11 @@ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @broadcast_test -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_{{.*}}shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel"] -// 
CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_{{.*}}shape // ----- @@ -231,7 +231,7 @@ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @broadcast_scalar // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32> -// CHECK: %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] [] +// CHECK: %[[A:.*]] = linalg.tensor_collapse_shape %[[ARG0]] [] // CHECK-SAME: tensor<1x1xf32> into tensor // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] @@ -251,7 +251,7 @@ ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32 } -> tensor<1x2x5xf32> - %3 = linalg.tensor_reshape %2 [[0, 1], [2]] + %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2]] : tensor<1x2x5xf32> into tensor<2x5xf32> return %3 : tensor<2x5xf32> } @@ -283,7 +283,7 @@ // CHECK: func @fold_unit_dim_for_init_tensor -// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> +// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> // CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor, f32 -> tensor // CHECK: %[[GENERIC:.+]] = linalg.generic @@ -291,7 +291,7 @@ // CHECK-SAME: iterator_types = ["reduction"] // CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor into tensor<1xf32> +// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_expand_shape %[[GENERIC]] [] : tensor into tensor<1xf32> // CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32> @@ -314,11 +314,11 @@ // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32> // CHECK: %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]] // CHECK-SAME: to tensor -// CHECK: %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]] +// CHECK: %[[RESULT1:.+]] = linalg.tensor_expand_shape 
%[[SUBTENSOR1]] // CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] // CHECK: %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]] // CHECK-SAME: to tensor -// CHECK: %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]] +// CHECK: %[[RESULT2:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR2]] // CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] // CHECK: return %[[RESULT1]], %[[RESULT2]] @@ -346,7 +346,7 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @unit_dim_for_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32> -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}}) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -354,7 +354,7 @@ // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -380,7 +380,7 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @unit_dim_for_reduction_keep_one // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32> -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32> // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}}) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -388,7 +388,7 @@ // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape 
%[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -415,7 +415,7 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @unit_dim_for_reduction_inner // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1], [2, 3]] +// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}}) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -423,7 +423,7 @@ // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -435,7 +435,7 @@ // CHECK-LABEL: func @subtensor_unit_dims // CHECK: %[[SUBTENSOR:.+]] = subtensor // CHECK-SAME: tensor<1x3xf32> to tensor -// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] [] +// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR]] [] // CHECK: return %[[RESULT]] // ----- @@ -445,7 +445,7 @@ return %0 : tensor<1x3xf32> } // CHECK-LABEL: func @subtensor_insert_unit_dims -// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] +// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} [] // CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]] // CHECK-SAME: tensor into tensor<1x3xf32> // CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -5,14 +5,14 @@ // CHECK-LABEL: func @reshape 
// CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor) -// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] {{\[}}[0, 1], [2]] : tensor into tensor +// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor into tensor // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor<16xf32>) outs(%[[RI]] : tensor) -// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor into tensor +// CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor into tensor // CHECK: return %[[RR]] : tensor func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor) -> tensor { - %0 = linalg.tensor_reshape %A [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] : tensor into tensor %2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, @@ -35,17 +35,17 @@ // CHECK-LABEL: func @reshape_multiple // CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>) // CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> -// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32> +// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32> // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>) -// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: %[[RR:.*]] = 
linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: return %[[RR]] : tensor<112x112x16xf32> func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, %C: tensor<16xf32>) -> tensor<112x112x16xf32> { - %0 = linalg.tensor_reshape %A [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> - %1 = linalg.tensor_reshape %B [[0, 1], [2]] + %1 = linalg.tensor_expand_shape %B [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> %3 = linalg.generic {indexing_maps = [ @@ -69,11 +69,11 @@ // Negative test, since the second source is broadcasted from d1 we cannot merge // d0 and d1 dimensions // CHECK-LABEL: func @reshape_negative -// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: linalg.tensor_expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: linalg.generic // CHECK: } -> tensor<112x112x16xf32> func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> { - %20 = linalg.tensor_reshape %A [[0, 1], [2]] + %20 = linalg.tensor_expand_shape %A [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> %22 = linalg.generic {indexing_maps = [ @@ -96,7 +96,7 @@ %cst_6 = constant 1.000000e+00 : f32 %cst_7 = constant 7.000000e+00 : f32 %cst_8 = constant 1.1920929E-7 : f32 - %25 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %25 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] : tensor<6x5xi32> into tensor<2x3x5xi32> %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32> %28 = linalg.generic { @@ -122,5 +122,5 @@ // CHECK: %[[OP:.+]] = linalg.generic // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>) -// 
CHECK: linalg.tensor_reshape %[[OP]] +// CHECK: linalg.tensor_expand_shape %[[OP]] // CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -348,21 +348,35 @@ func @reshape(%arg0: memref) { // expected-error @+1 {{expected non-zero memref ranks}} - %0 = linalg.reshape %arg0 [[0]] : memref into memref + %0 = linalg.expand_shape %arg0 [[0]] : memref into memref +} + +// ----- + +func @collapse_to_higher_rank(%arg0: memref) { + // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} + %0 = linalg.collapse_shape %arg0 [[0]] : memref into memref<1xf32> +} + +// ----- + +func @expand_to_smaller_rank(%arg0: memref<1xf32>) { + // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} + %0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref } // ----- func @reshape(%arg0: memref) { // expected-error @+1 {{expected to collapse or expand dims}} - %0 = linalg.reshape %arg0 [[0]] : memref into memref + %0 = linalg.collapse_shape %arg0 [[0]] : memref into memref } // ----- func @reshape(%arg0: memref) { // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}} - %0 = linalg.reshape %arg0 [[0, 1]] : + %0 = linalg.collapse_shape %arg0 [[0, 1]] : memref into memref } @@ -370,7 +384,7 @@ func @reshape(%arg0: memref) { // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}} - %0 = linalg.reshape %arg0 [[0, 1], [1, 2]] : + %0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] : memref into memref } @@ -378,7 +392,7 @@ func @reshape(%arg0: memref) { // expected-error @+1 {{expected collapsed type to be 'memref', but got 'memref (d0 * s0 + d1)>>'}} - %0 = linalg.reshape %arg0 [[0, 1], [2]] : + %0 = linalg.collapse_shape %arg0 [[0, 1], 
[2]] : memref into memref (d0 * s0 + d1)>> } @@ -455,7 +469,7 @@ (%arg0: tensor) -> tensor { // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]] : tensor into tensor return %0 : tensor } @@ -466,7 +480,7 @@ (%arg0: memref) -> memref { // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} - %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]] : memref into memref return %0 : memref } @@ -477,7 +491,7 @@ (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> { // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]] : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32> return %0 : tensor<2x3x2x4x5xf32> } @@ -488,7 +502,7 @@ (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32> { // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] : tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32> return %0 : tensor<2x3x20xf32> } @@ -499,7 +513,7 @@ (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> { // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]] : memref<2x3x20xf32> into memref<2x3x2x4x5xf32> return %0 : memref<2x3x2x4x5xf32> } @@ -510,87 +524,87 @@ (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> { // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]] + %0 = linalg.collapse_shape 
%arg0 [[0], [1], [2, 3, 4]] : memref<2x3x2x4x5xf32> into memref<2x3x20xf32> return %0 : memref<2x3x20xf32> } // ----- -func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor +func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] : tensor into tensor return %0 : tensor } // ----- -func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor +func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] : tensor into tensor return %0 : tensor } // ----- -func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) -> tensor +func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor return %0 : tensor } // ----- -func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor +func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]] + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] : tensor into tensor return %0 : tensor } // ----- -func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref) -> memref +func @illegal_expanding_reshape_mixed_memref(%arg0 : memref) -> memref { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.reshape %arg0 [[0, 1], [2]] + %0 = linalg.expand_shape 
%arg0 [[0, 1], [2]] : memref into memref return %0 : memref } // ----- -func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref) -> memref +func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref) -> memref { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.reshape %arg0 [[0], [1, 2]] + %0 = linalg.expand_shape %arg0 [[0], [1, 2]] : memref into memref return %0 : memref } // ----- -func @illegal_expanding_reshape_mixed_memref(%arg0 : memref) -> memref +func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref) -> memref { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.reshape %arg0 [[0, 1], [2]] + %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] : memref into memref return %0 : memref } // ----- -func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref) -> memref +func @illegal_collapse_reshape_mixed_memref_2(%arg0 : memref) -> memref { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.reshape %arg0 [[0], [1, 2]] + %0 = linalg.collapse_shape %arg0 [[0], [1, 2]] : memref into memref return %0 : memref } diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -14,13 +14,13 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)> -func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { +func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { // Reshapes that expand a contiguous tensor with some 1's. 
- %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] + %0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> return %0 : memref<1x3x4x1x5xf32> } -// CHECK-LABEL: func @reshape_static_expand +// CHECK-LABEL: func @expand_shape_static // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> @@ -49,12 +49,12 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { - %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] : +func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { + %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : memref<1x3x4x1x5xf32> into memref<3x4x5xf32> return %0 : memref<3x4x5xf32> } -// CHECK-LABEL: func @reshape_static_collapse +// CHECK-LABEL: func @collapse_shape_static // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> @@ -75,11 +75,11 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref { - %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref +func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref { + %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref return %0 : memref } -// CHECK-LABEL: 
func @reshape_fold_zero_dim +// CHECK-LABEL: func @collapse_shape_fold_zero_dim // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> @@ -88,11 +88,11 @@ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> -func @reshape_expand_zero_dim(%arg0 : memref) -> memref<1x1xf32> { - %0 = linalg.reshape %arg0 [] : memref into memref<1x1xf32> +func @expand_shape_zero_dim(%arg0 : memref) -> memref<1x1xf32> { + %0 = linalg.expand_shape %arg0 [] : memref into memref<1x1xf32> return %0 : memref<1x1xf32> } -// CHECK-LABEL: func @reshape_expand_zero_dim +// CHECK-LABEL: func @expand_shape_zero_dim // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -6,7 +6,7 @@ %arg1 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] : + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], @@ -25,16 +25,16 @@ // CHECK: func @generic_op_reshape_producer_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2], [3] -// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] 
+// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0], [1], [2, 3] // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor, tensor) // CHECK-SAME: outs(%{{.+}} : tensor) -// CHECK: %[[T4:.+]] = linalg.tensor_reshape %[[T3]] +// CHECK: %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]] // CHECK-SAME: [0], [1], [2, 3] // CHECK-SAME: tensor into tensor // CHECK: return %[[T4]] @@ -55,19 +55,19 @@ %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] : + %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @generic_op_reshape_consumer_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor) +// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic @@ -94,7 +94,7 @@ %1 = addf %arg0, %arg1 : f32 linalg.yield %1 : f32 } -> tensor - %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]] + %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] : tensor into tensor return %d : tensor } @@ -104,10 +104,10 @@ // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = 
linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] // CHECK-SAME: tensor into tensor<3x4x?x?x2x?xf32> -// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor<3x4x?x?xf32> // CHECK: %[[T3:.+]] = linalg.generic @@ -136,7 +136,7 @@ %2 = mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 } -> tensor<264x4xf32> - %2 = linalg.tensor_reshape %1 [[0, 1], [2]] : + %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] : tensor<264x4xf32> into tensor<8x33x4xf32> return %2 : tensor<8x33x4xf32> } @@ -144,7 +144,7 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_reshape_consumer_static // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0, 1], [2] // CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> // CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] @@ -163,7 +163,7 @@ %arg1 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]]: + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]: tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], @@ -229,7 +229,7 @@ %5 = addi %3, %4 : i32 linalg.yield %5 : i32 } -> tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] : + %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -279,7 +279,7 @@ %7 = addi %5, %6 : i32 linalg.yield %7 : i32 } -> tensor<6x4x210xi32> - %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]] + %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> return %d : tensor<2x3x4x5x6x7xi32> } @@ -293,9 +293,9 @@ // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32> // CHECK-SAME: %[[ARG1:.+]]: 
tensor<210x4xi32> -// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-DAG: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] -// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-DAG: %[[T2:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7] // CHECK: %[[T4:.+]] = linalg.generic @@ -326,7 +326,7 @@ func @reshape_as_producer_projected_permutation( %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<33x8x?xi32> into tensor<264x?xi32> %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, @@ -372,7 +372,7 @@ // CHECK: %[[T5:.+]] = index_cast %[[IDX3]] : index to i32 // CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32 // CHECK: linalg.yield %[[T6]] : i32 -// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]] +// CHECK: %[[RES2:.+]] = linalg.tensor_collapse_shape %[[RES]] // CHECK-SAME: [0, 1], [2], [3] // CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32> // CHECK: return %[[RES2]] : tensor<264x?x4xi32> @@ -394,7 +394,7 @@ %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] : + %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -404,10 +404,10 @@ // CHECK: func @generic_op_reshape_consumer_fusion_projected // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: 
tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic @@ -420,7 +420,7 @@ // ----- func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> { - %0 = linalg.tensor_reshape %arg0 [[0, 1]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]] : tensor<1x5xf32> into tensor<5xf32> %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32> %2 = linalg.generic @@ -434,7 +434,7 @@ return %2 : tensor<5x5xf32> } // CHECK: func @unit_dim_reshape_expansion -// CHECK-DAG: linalg.tensor_reshape +// CHECK-DAG: linalg.tensor_collapse_shape // CHECK-DAG: linalg.init_tensor // CHECK: linalg.generic @@ -450,14 +450,14 @@ ^bb0(%arg2: f32, %arg3: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<5x5xf32> - %2 = linalg.tensor_reshape %1 [[0, 1], [2]] + %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] : tensor<5x5xf32> into tensor<5x1x5xf32> return %2 : tensor<5x1x5xf32> } // CHECK: func @unit_dim_reshape_collapse // CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_expand_shape // ----- @@ -465,7 +465,7 @@ (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor) -> tensor { %c1 = constant 1 : index - %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3, 4], [5]] + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<1x?x1x2x1x4xf32> into tensor %1 = memref.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32> %2 = linalg.init_tensor [%1, 2, 4] : tensor @@ -483,7 +483,7 @@ return %3 : tensor } // CHECK: func @unit_dim_reshape_expansion_full -// CHECK-DAG: linalg.tensor_reshape +// CHECK-DAG: linalg.tensor_collapse_shape // CHECK-DAG: linalg.init_tensor // CHECK: linalg.generic // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor, tensor) @@ -491,7 +491,7 @@ // FOLDUNITDIM: func @unit_dim_reshape_expansion_full // FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32> // FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor -// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]] +// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = 
linalg.tensor_expand_shape %[[ARG1]] // FOLDUNITDIM: linalg.generic // FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>) // FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>) diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir @@ -3,7 +3,7 @@ #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func @generic_op_reshape_producer_fusion(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] : + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map0], @@ -22,7 +22,7 @@ // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @generic_op_reshape_producer_fusion // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2], [3] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]] @@ -46,7 +46,7 @@ %3 = addi %arg6, %2 : i32 linalg.yield %3 : i32 } -> tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] : + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -54,21 +54,21 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> // CHECK: func @generic_op_reshape_consumer_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2, 3] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] // CHECK-SAME: outs(%[[T0]] : tensor) // CHECK: %[[IDX:.+]] = linalg.index 0 : index // CHECK-NEXT: 
%[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32 -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_collapse_shape // ----- #map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32> %2 = linalg.generic @@ -84,7 +84,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_021_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -93,7 +93,7 @@ #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32> %2 = linalg.generic @@ -109,7 +109,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> // CHECK: func @generic_op_120_permutation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -120,7 +120,7 @@ #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> 
tensor<5x3x7xf32> { - %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]] + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32> %2 = linalg.generic @@ -137,7 +137,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_102_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape +// CHECK-NOT: linalg.tensor_expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -156,7 +156,7 @@ ^bb0(%arg2: f32, %arg3 : f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<5x3x7xf32> - %2 = linalg.tensor_reshape %1 [[0], [1, 2]] + %2 = linalg.tensor_collapse_shape %1 [[0], [1, 2]] : tensor<5x3x7xf32> into tensor<5x21xf32> return %2 : tensor<5x21xf32> } @@ -165,7 +165,7 @@ // CHECK: func @generic_op_102_permultation_reshape_consumer_fusion // CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32> // CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7] -// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[T0]] +// CHECK: %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]] // CHECK-SAME: [0], [1, 2] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] @@ -188,7 +188,7 @@ %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] : + %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -197,5 +197,5 @@ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[NOFUSE:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[NOFUSE]] +// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[NOFUSE]] // CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- 
a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -563,92 +563,92 @@ func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) { // Reshapes that collapse and expand back a contiguous buffer. - %0 = linalg.reshape %arg0 [[0, 1], [2]] : + %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] : memref<3x4x5xf32> into memref<12x5xf32> - %r0 = linalg.reshape %0 [[0, 1], [2]] : + %r0 = linalg.expand_shape %0 [[0, 1], [2]] : memref<12x5xf32> into memref<3x4x5xf32> - %1 = linalg.reshape %arg0 [[0], [1, 2]] : + %1 = linalg.collapse_shape %arg0 [[0], [1, 2]] : memref<3x4x5xf32> into memref<3x20xf32> - %r1 = linalg.reshape %1 [[0], [1, 2]] : + %r1 = linalg.expand_shape %1 [[0], [1, 2]] : memref<3x20xf32> into memref<3x4x5xf32> - %2 = linalg.reshape %arg0 [[0, 1, 2]] : + %2 = linalg.collapse_shape %arg0 [[0, 1, 2]] : memref<3x4x5xf32> into memref<60xf32> - %r2 = linalg.reshape %2 [[0, 1, 2]] : + %r2 = linalg.expand_shape %2 [[0, 1, 2]] : memref<60xf32> into memref<3x4x5xf32> // Reshapes that expand and collapse back a contiguous buffer with some 1's. - %3 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] : + %3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> - %r3 = linalg.reshape %3 [[0, 1], [2], [3, 4]] : + %r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] : memref<1x3x4x1x5xf32> into memref<3x4x5xf32> // Reshapes on tensors. 
- %t0 = linalg.tensor_reshape %arg1 [[0, 1], [2], [3, 4]] : + %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] : tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> - %rt0 = linalg.tensor_reshape %t0 [[0, 1], [2], [3, 4]] : + %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] : tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> - %t1 = linalg.tensor_reshape %arg2 [[0, 1], [2], [3, 4]] : + %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] : tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> - %rt1 = linalg.tensor_reshape %t1 [[0], [1, 2], [3, 4]] : + %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] : tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> return } // CHECK-LABEL: func @reshape_static -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0], [1, 2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]] // CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0], [1, 2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]] // CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1, 2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]] // CHECK-SAME: memref<3x4x5xf32> into memref<60xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1, 2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]] // CHECK-SAME: memref<60xf32> into memref<3x4x5xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] // CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32> -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 
1], [2], [3, 4]] // CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32> // -// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> -// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> -// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> -// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> +// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> +// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> +// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> +// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> // ----- func @reshape_dynamic(%arg0: memref, %arg1: memref, %arg2: memref) { - %0 = linalg.reshape %arg0 [[0, 1], [2]] : + %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] : memref into memref - %r0 = linalg.reshape %0 [[0, 1], [2]] : + %r0 = linalg.expand_shape %0 [[0, 1], [2]] : memref into memref - %1 = linalg.reshape %arg1 [[0, 1], [2]] : + %1 = linalg.collapse_shape %arg1 [[0, 1], [2]] : memref into memref - %r1 = linalg.reshape %1 [[0, 1], [2]] : + %r1 = linalg.expand_shape %1 [[0, 1], [2]] : memref into memref - %2 = linalg.reshape %arg2 [[0, 1], [2]] : + %2 = linalg.collapse_shape %arg2 [[0, 1], [2]] : memref into memref - %r2 = linalg.reshape %2 [[0, 1], [2]] : + %r2 = linalg.expand_shape %2 [[0, 1], [2]] : memref into memref return } // CHECK-LABEL: func @reshape -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref 
into memref -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref -// CHECK: linalg.reshape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref func @named_ops(%a3: memref, %b3: memref, %c3: memref, @@ -679,25 +679,25 @@ func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) -> (tensor, tensor<1x1xf32>) { - %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor - %1 = linalg.tensor_reshape %0 [] : tensor into tensor<1x1xf32> + %0 = linalg.tensor_collapse_shape %arg0 [] : tensor<1x1xf32> into tensor + %1 = linalg.tensor_expand_shape %0 [] : tensor into tensor<1x1xf32> return %0, %1 : tensor, tensor<1x1xf32> } // CHECK-LABEL: func @tensor_reshape_zero_dim -// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor -// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor into tensor<1x1xf32> +// CHECK: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor +// CHECK: linalg.tensor_expand_shape %{{.*}} [] : tensor into tensor<1x1xf32> // ----- func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref) -> (memref, memref<1x1xf32>) { - %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref - %1 = linalg.reshape %0 [] : memref into memref<1x1xf32> + %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref + %1 = linalg.expand_shape %0 [] : memref into memref<1x1xf32> return %0, %1 : memref, memref<1x1xf32> } // CHECK-LABEL: func @memref_reshape_zero_dim -// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref -// CHECK: linalg.reshape %{{.*}} [] : memref into memref<1x1xf32> +// CHECK: linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref +// CHECK: linalg.expand_shape %{{.*}} [] : 
memref into memref<1x1xf32> // ----- @@ -716,12 +716,12 @@ func @legal_collapsing_reshape_dynamic_tensor (%arg0: tensor) -> tensor { - %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]] : + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] : tensor into tensor return %0 : tensor } // CHECK: func @legal_collapsing_reshape_dynamic_tensor -// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] // ----- @@ -729,12 +729,12 @@ func @legal_collapsing_reshape_dynamic_memref (%arg0: memref) -> memref { - %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]] : + %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] : memref into memref return %0 : memref } // CHECK: func @legal_collapsing_reshape_dynamic_memref -// CHECK: linalg.reshape +// CHECK: linalg.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] // -----