diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -396,91 +396,6 @@ def IndexListArrayAttr : TypedArrayAttrBase; -class Linalg_ReshapeOp : Linalg_ReshapeLikeOp]>, - Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, - Results<(outs AnyStridedMemRef:$result)> { - let extraClassDeclaration = commonExtraClassDeclaration # [{ - MemRefType getSrcType() { return src().getType().cast(); } - MemRefType getResultType() { return result().getType().cast(); } - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; - let printer = [{ return ::print(p, *this); }]; -} - -def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> { - let summary = "operation to produce a memref with a higher rank."; - let description = [{ - The `linalg.expand_shape` op produces a new view with a higher rank whose - sizes are a reassociation of the original `view`. Depending on whether or - not the reassociated MemRefType is contiguous, the resulting memref may - require explicit alloc and copies. - - A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. - - For now, it is assumed that either: - 1. a reassociation produces and consumes contiguous MemRefType or, - 2. the reshape op will be folded into its consumers (by changing the shape - of the computations). - All other cases are undefined behavior and a reshape op may not lower to - LLVM if it cannot be proven statically that it does not require alloc+copy. - - The operand memref type when dimensions can be zero-ranked if the result - memref type is statically shaped with all dimensions being unit extent. In - such case the reassociation map is empty. - - The verification rule is that the reassociation maps are applied to the - result memref with the larger rank to obtain the operand memref with the - smaller rank. - - Example: - - ```mlir - // Dimension expansion i -> (i', j') and (k) -> (k') - %1 = linalg.expand_shape %0 [[0, 1], [2]] : - memref into memref - ``` - }]; -} - -def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> { - let summary = "operation to produce a memref with a smaller rank."; - let description = [{ - The `linalg.collapse_shape` op produces a new view with a smaller rank - whose sizes are a reassociation of the original `view`. Depending on - whether or not the reassociated MemRefType is contiguous, the resulting - memref may require explicit alloc and copies. - - A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. - - For now, it is assumed that either: - 1. a reassociation produces and consumes contiguous MemRefType or, - 2. the reshape op will be folded into its consumers (by changing the shape - of the computations). - All other cases are undefined behavior and a reshape op may not lower to - LLVM if it cannot be proven statically that it does not require alloc+copy. - - The result memref type of a reshape can be zero-ranked if the operand - memref type is statically shaped with all dimensions being unit extent. In - such case the reassociation map is empty. - - The verification rule is that the reassociation maps are applied to the - operand memref with the larger rank to obtain the result memref with the - smaller rank. 
-
-    Examples:
-
-    ```mlir
-    // Dimension collapse (i, j) -> i' and k -> k'
-    %1 = linalg.collapse_shape %0 [[0, 1], [2]] :
-      memref<?x?x?xf32> into memref<?x?xf32>
-    ```
-  }];
-}
-
 class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
     mnemonic, [DeclareOpInterfaceMethods,
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
       SingleBlockImplicitTerminator<"AllocaScopeReturnOp">,
@@ -225,8 +225,8 @@
   Here, `%myalloca` memref is valid within the explicitly delimited scope
   and is automatically deallocated at the end of the given region. Conceptually,
-  `memref.alloca_scope` is a passthrough operation with
-  `AutomaticAllocationScope` that spans the body of the region within the operation.
+  `memref.alloca_scope` is a passthrough operation with
+  `AutomaticAllocationScope` that spans the body of the region within the operation.

   `memref.alloca_scope` may also return results that are defined in the nested
   region. To return a value, one should use `memref.alloca_scope.return`
@@ -251,14 +251,14 @@
 // AllocaScopeReturnOp
 //===----------------------------------------------------------------------===//

-def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
+def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
   [HasParent<"AllocaScopeOp">,
    NoSideEffect,
    ReturnLike,
    Terminator]> {
   let summary = "terminator for alloca_scope operation";
   let description = [{
-    `memref.alloca_scope.return` operation returns zero or more SSA values
+    `memref.alloca_scope.return` operation returns zero or more SSA values
     from the region within `memref.alloca_scope`. If no values are returned,
     the return operation may be omitted. Otherwise, it has to be present
     to indicate which values are going to be returned. For example:
@@ -927,6 +927,150 @@
   }];
 }

+//===----------------------------------------------------------------------===//
+// ExpandShapeOp / CollapseShapeOp
+//===----------------------------------------------------------------------===//
+
+def IndexListArrayAttr :
+    TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+
+class MemRef_ReassociativeReshapeOp<string mnemonic, list<OpTrait> traits = []> :
+    MemRef_Op<mnemonic, !listconcat(traits, [NoSideEffect, ViewLikeOpInterface])>,
+    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef:$result)> {
+  let builders = [
+    // Builders for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, src, reassociationMaps, attrs);
+    }]>,

+    // Builders for a reshape whose result type is passed explicitly. This may
+    // be either a contracting or expanding reshape.
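+    // Note that these builders do not re-infer the result type; the op
+    // verifier still checks it for consistency with `src` and `reassociation`.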
+ OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + + code commonExtraClassDeclaration = [{ + SmallVector getReassociationMaps(); + SmallVector getReassociationExprs(); + SmallVector getReassociationIndices() { + SmallVector reassociationIndices; + for (auto attr : reassociation()) + reassociationIndices.push_back(llvm::to_vector<2>( + llvm::map_range(attr.cast(), [&](Attribute indexAttr) { + return indexAttr.cast().getInt(); + }))); + return reassociationIndices; + }; + MemRefType getSrcType() { return src().getType().cast(); } + MemRefType getResultType() { return result().getType().cast(); } + Value getViewSource() { return src(); } + }]; + + let hasFolder = 1; + let hasCanonicalizer = 1; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; +} + +def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> { + let summary = "operation to produce a memref with a higher rank."; + let description = [{ + The `memref.expand_shape` op produces a new view with a higher rank whose + sizes are a reassociation of the original `view`. Depending on whether or + not the reassociated MemRefType is contiguous, the resulting memref may + require explicit alloc and copies. + + A reassociation is defined as a continuous grouping of dimensions and is + represented with an array of I64ArrayAttr attribute. + + For now, it is assumed that either: + 1. a reassociation produces and consumes contiguous MemRefType or, + 2. the reshape op will be folded into its consumers (by changing the shape + of the computations). + All other cases are undefined behavior and a reshape op may not lower to + LLVM if it cannot be proven statically that it does not require alloc+copy. + + The operand memref type when dimensions can be zero-ranked if the result + memref type is statically shaped with all dimensions being unit extent. In + such case the reassociation map is empty. + + The verification rule is that the reassociation maps are applied to the + result memref with the larger rank to obtain the operand memref with the + smaller rank. + + Example: + + ```mlir + // Dimension expansion i -> (i', j') and (k) -> (k') + %1 = memref.expand_shape %0 [[0, 1], [2]] : + memref into memref + ``` + }]; + let extraClassDeclaration = commonExtraClassDeclaration; +} + +def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> { + let summary = "operation to produce a memref with a smaller rank."; + let description = [{ + The `memref.collapse_shape` op produces a new view with a smaller rank + whose sizes are a reassociation of the original `view`. Depending on + whether or not the reassociated MemRefType is contiguous, the resulting + memref may require explicit alloc and copies. + + A reassociation is defined as a continuous grouping of dimensions and is + represented with an array of I64ArrayAttr attribute. + + For now, it is assumed that either: + 1. a reassociation produces and consumes contiguous MemRefType or, + 2. 
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
+  let summary = "operation to produce a memref with a smaller rank.";
+  let description = [{
+    The `memref.collapse_shape` op produces a new view with a smaller rank
+    whose sizes are a reassociation of the original `view`. Depending on
+    whether or not the reassociated MemRefType is contiguous, the resulting
+    memref may require explicit alloc and copies.
+
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+    For now, it is assumed that either:
+      1. a reassociation produces and consumes contiguous MemRefType or,
+      2. the reshape op will be folded into its consumers (by changing the
+         shape of the computations).
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    The result memref type of a reshape can be zero-ranked if the operand
+    memref type is statically shaped with all dimensions being unit extent. In
+    such a case the reassociation map is empty.
+
+    The verification rule is that the reassociation maps are applied to the
+    operand memref with the larger rank to obtain the result memref with the
+    smaller rank.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %1 = memref.collapse_shape %0 [[0, 1], [2]] :
+      memref<?x?x?xf32> into memref<?x?xf32>
+    ```
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -1,4 +1,4 @@
-//===- RehshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
+//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -26,7 +26,7 @@
 using ReassociationExprs = SmallVector<AffineExpr, 2>;

 /// Attribute name for the ArrayAttr which encodes reassociation indices.
-constexpr StringRef getReassociationAttrName();
+constexpr StringRef getReassociationAttrName() { return "reassociation"; }

 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
@@ -45,6 +45,23 @@
                           ArrayRef<ReassociationIndices> consumerReassociations,
                           MLIRContext *context);

+/// Convert reassociation indices to affine expressions.
+SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices);
+
+/// Constructs affine maps out of Array<Array<AffineExpr>>.
+SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
+
+/// Wraps a list of reassociations in an ArrayAttr.
+ArrayAttr
+getReassociationIndicesAttribute(OpBuilder &b,
+                                 ArrayRef<ReassociationIndices> reassociation);
+
+/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
+SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+
 /// Return the reassociation maps to use to reshape given the source type and
 /// the target type when possible. Return llvm::None when this computation
 /// failed.
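/// For example (a sketch of the intended behavior): collapsing
/// `memref<4x5xf32>` into `memref<20xf32>` yields the reassociation
/// [[0, 1]], while incompatible shapes yield llvm::None.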
@@ -78,7 +95,7 @@
   p << "] ";
   p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{op.getReassociationAttrName()});
+                          /*elidedAttrs=*/{getReassociationAttrName()});
   p << ": " << op.src().getType() << " into " << op.getType();
 }
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -93,48 +94,6 @@
   }
 };

-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
-template <typename ReshapeOp>
-class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
-public:
-  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
-  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
-
-  LogicalResult
-  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    MemRefType dstType = reshapeOp.getResultType();
-
-    if (!dstType.hasStaticShape())
-      return failure();
-
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto res = getStridesAndOffset(dstType, strides, offset);
-    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
-          return ShapedType::isDynamicStrideOrOffset(val);
-        }))
-      return failure();
-
-    ReshapeOpAdaptor adaptor(operands);
-    MemRefDescriptor baseDesc(adaptor.src());
-    Location loc = reshapeOp->getLoc();
-    auto desc =
-        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
-                                this->typeConverter->convertType(dstType));
-    desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
-    desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
-    desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
-    for (auto en : llvm::enumerate(dstType.getShape()))
-      desc.setConstantSize(rewriter, loc, en.index(), en.value());
-    for (auto en : llvm::enumerate(strides))
-      desc.setConstantStride(rewriter, loc, en.index(), en.value());
-    rewriter.replaceOp(reshapeOp, {desc});
-    return success();
-  }
-};

 // YieldOp produces an LLVM::ReturnOp.
 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
@@ -153,9 +112,7 @@
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns) {
-  patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
-               ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
-      converter);
+  patterns.add<RangeOpConversion, YieldOpConversion>(converter);

   // Populate the type conversions for the linalg types.
   converter.addConversion(
@@ -176,6 +133,7 @@
   RewritePatternSet patterns(&getContext());
   LLVMTypeConverter converter(&getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns);
+  populateMemRefToLLVMConversionPatterns(converter, patterns);

   LLVMConversionTarget target(getContext());
   target.addIllegalOp();
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -186,9 +186,7 @@
   ConversionTarget target(getContext());
   target.addLegalDialect();
-  target.addLegalOp();
-  target.addLegalOp();
+  target.addLegalOp();
   RewritePatternSet patterns(&getContext());
   populateLinalgToStandardConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1000,6 +1000,49 @@
   }
 };

+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for target MemRef with static sizes
+// and strides.
+template <typename ReshapeOp>
+class ReassociatingReshapeOpConversion
+    : public ConvertOpToLLVMPattern<ReshapeOp> {
+public:
+  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
+  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType dstType = reshapeOp.getResultType();
+
+    if (!dstType.hasStaticShape())
+      return failure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto res = getStridesAndOffset(dstType, strides, offset);
+    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+          return ShapedType::isDynamicStrideOrOffset(val);
+        }))
+      return failure();
+
+    ReshapeOpAdaptor adaptor(operands);
+    MemRefDescriptor baseDesc(adaptor.src());
+    Location loc = reshapeOp->getLoc();
+    auto desc =
+        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
+                                this->typeConverter->convertType(dstType));
+    desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
+    desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
+    desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
+    for (auto en : llvm::enumerate(dstType.getShape()))
+      desc.setConstantSize(rewriter, loc, en.index(), en.value());
+    for (auto en : llvm::enumerate(strides))
+      desc.setConstantStride(rewriter, loc, en.index(), en.value());
+    rewriter.replaceOp(reshapeOp, {desc});
+    return success();
+  }
+};
+
 /// Conversion pattern that transforms a subview op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -1355,6 +1398,8 @@
                MemRefReinterpretCastOpLowering,
                MemRefReshapeOpLowering,
                PrefetchOpLowering,
+               ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
+               ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
                StoreOpLowering,
                SubViewOpLowering,
                TransposeOpLowering,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
@@ -1103,14 +1104,6 @@
 // ReshapeOp
 //===----------------------------------------------------------------------===//

-static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
-  ::mlir::printReshapeOp(p, op);
-}
-
-static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
-  ::mlir::printReshapeOp(p, op);
-}
-
 static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
   ::mlir::printReshapeOp(p, op);
 }
@@ -1260,20 +1253,6 @@
   return reassociationMaps;
 }

-SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
-  return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}
-SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
-  return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}

 SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
@@ -1422,71 +1401,6 @@
   return b.getArrayAttr(reassociationAttr);
 }

-void mlir::linalg::ExpandShapeOp::build(
-    OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ReassociationIndices> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
-  auto memRefType = src.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
-
-void mlir::linalg::CollapseShapeOp::build(
-    OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ReassociationIndices> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
-  auto memRefType = src.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
-
-template <typename ReshapeOp,
-          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
-static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
-                                     MemRefType collapsedType) {
-  if (failed(
-          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
-    return failure();
-  auto maps = op.getReassociationMaps();
-  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
-  if (collapsedType != expectedType)
-    return op.emitOpError("expected collapsed type to be ")
op.emitOpError("expected collapsed type to be ") - << expectedType << ", but got " << collapsedType; - return success(); -} - -static LogicalResult verify(ExpandShapeOp op) { - return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); -} - -void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add, - CollapseMixedReshapeOps>(context); -} - -static LogicalResult verify(CollapseShapeOp op) { - return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); -} - -void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add, - CollapseMixedReshapeOps>(context); -} - //===----------------------------------------------------------------------===// // TensorReshapeOp //===----------------------------------------------------------------------===// @@ -2433,16 +2347,6 @@ // TODO: Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of // OpInterfaces where a Linalg "named" op **isa** LinalgOp. -OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); - return foldReshapeOp(*this, operands); -} -OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); - return foldReshapeOp(*this, operands); -} OpFoldResult TensorExpandShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -155,8 +155,8 @@ public: using OpConversionPattern::OpConversionPattern; using ReshapeOp = typename std::conditional_t< - std::is_same::value, ExpandShapeOp, - CollapseShapeOp>; + std::is_same::value, + memref::ExpandShapeOp, memref::CollapseShapeOp>; LogicalResult matchAndRewrite(TensorReshapeOp op, ArrayRef operands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -352,7 +352,7 @@ convertAffineMapArrayToExprs(reassociationMap)); } if (origResultType.isa()) { - return rewriter.create( + return rewriter.create( loc, origResultType, result, convertAffineMapArrayToExprs(reassociationMap)); } @@ -368,7 +368,7 @@ if (operandType == newInputOutputType) return operand; if (operandType.isa()) { - return rewriter.create( + return rewriter.create( loc, newInputOutputType, operand, convertAffineMapArrayToExprs(reassociationMap)); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1300,6 +1300,189 @@ return success(); } +//===----------------------------------------------------------------------===// +// Reassociative reshape ops +//===----------------------------------------------------------------------===// + +SmallVector CollapseShapeOp::getReassociationMaps() { + return getSymbolLessAffineMaps(getReassociationExprs()); +} +SmallVector CollapseShapeOp::getReassociationExprs() { + OpBuilder b(this->getContext()); + return convertReassociationIndicesToExprs(b, getReassociationIndices()); +} + +SmallVector ExpandShapeOp::getReassociationMaps() { + return 
+}
+SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
+  OpBuilder b(this->getContext());
+  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+static void print(OpAsmPrinter &p, ExpandShapeOp op) {
+  ::mlir::printReshapeOp(p, op);
+}
+
+static void print(OpAsmPrinter &p, CollapseShapeOp op) {
+  ::mlir::printReshapeOp(p, op);
+}
+
+/// Detect whether memref dims [dim, dim + extent) can be reshaped without
+/// copies.
+static bool isReshapableDimBand(unsigned dim, unsigned extent,
+                                ArrayRef<int64_t> sizes,
+                                ArrayRef<AffineExpr> strides) {
+  assert(sizes.size() == strides.size() && "mismatched ranks");
+  // off by 1 indexing to avoid out of bounds
+  //                       V
+  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
+    // Only bands of static shapes are reshapable. This is due to the fact that
+    // there is no relation between dynamic sizes and dynamic strides: we do not
+    // have enough information to know whether a "-1" size corresponds to the
+    // proper symbol in the AffineExpr of a stride.
+    if (ShapedType::isDynamic(sizes[dim + 1]))
+      return false;
+    // TODO: Refine this by passing the proper nDims and nSymbols so we can
+    // simplify on the fly and catch more reshapable cases.
+    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
+      return false;
+  }
+  return true;
+}
+
+/// Compute the MemRefType obtained by applying the `reassociation` (which is
+/// expected to be valid) to `type`.
+/// If `type` is a contiguous MemRefType, this always produces a contiguous
+/// MemRefType.
+static MemRefType
+computeReshapeCollapsedType(MemRefType type,
+                            ArrayRef<AffineMap> reassociation) {
+  auto sizes = type.getShape();
+  AffineExpr offset;
+  SmallVector<AffineExpr, 4> strides;
+  auto status = getStridesAndOffset(type, strides, offset);
+  (void)status;
+  assert(succeeded(status) && "expected strided memref");
+
+  SmallVector<int64_t, 4> newSizes;
+  newSizes.reserve(reassociation.size());
+  SmallVector<AffineExpr, 4> newStrides;
+  newStrides.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    AffineExpr stride = strides[currentDim + dim - 1];
+    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+      size = ShapedType::kDynamicSize;
+      stride = AffineExpr();
+    } else {
+      for (unsigned d = 0; d < dim; ++d)
+        size *= sizes[currentDim + d];
+    }
+    newSizes.push_back(size);
+    newStrides.push_back(stride);
+    currentDim += dim;
+  }
+
+  // Early-exit: if `type` is contiguous, the result must be contiguous.
+  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+
+  // Convert back to int64_t because we don't have enough information to create
+  // new strided layouts from AffineExpr only. This corresponds to a case where
+  // copies may be necessary.
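+  // Any stride or offset that did not simplify to an AffineConstantExpr above
+  // is conservatively marked dynamic in the strided layout built below.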
+  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
+  if (auto o = offset.dyn_cast<AffineConstantExpr>())
+    intOffset = o.getValue();
+  SmallVector<int64_t, 4> intStrides;
+  intStrides.reserve(strides.size());
+  for (auto stride : newStrides) {
+    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
+      intStrides.push_back(cst.getValue());
+    else
+      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  auto layout =
+      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
+  return canonicalizeStridedLayout(
+      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+}
+
+void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+                          ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = src.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(
+      memRefType, getSymbolLessAffineMaps(
+                      convertReassociationIndicesToExprs(b, reassociation)));
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(getReassociationAttrName(),
+                      getReassociationIndicesAttribute(b, reassociation));
+}
+
+void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = src.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(
+      memRefType, getSymbolLessAffineMaps(
+                      convertReassociationIndicesToExprs(b, reassociation)));
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(getReassociationAttrName(),
+                      getReassociationIndicesAttribute(b, reassociation));
+}
+
+template <typename ReshapeOp,
+          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
+                                     MemRefType collapsedType) {
+  if (failed(
+          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
+    return failure();
+  auto maps = op.getReassociationMaps();
+  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
+static LogicalResult verify(ExpandShapeOp op) {
+  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+}
+
+void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+}
+
+static LogicalResult verify(CollapseShapeOp op) {
+  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+}
+
+void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
+}
+
+OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return foldReshapeOp(*this, operands);
+}
+OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return foldReshapeOp(*this, operands);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -15,8 +15,6 @@
 using namespace mlir;

-constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
-
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
@@ -183,6 +181,70 @@
   return composedIndices;
 }
+SmallVector<SmallVector<AffineExpr, 2>, 2>
+mlir::convertReassociationIndicesToExprs(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
+  for (const auto &indices : reassociationIndices) {
+    SmallVector<AffineExpr, 2> reassociationMap;
+    reassociationMap.reserve(indices.size());
+    for (int64_t index : indices)
+      reassociationMap.push_back(b.getAffineDimExpr(index));
+    reassociationMaps.push_back(std::move(reassociationMap));
+  }
+  return reassociationMaps;
+}
+
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
+  unsigned pos = 0;
+  for (const auto &exprs : exprArrays) {
+    for (auto expr : exprs) {
+      expr.walk([&pos](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineExprTy>())
+          pos = std::max(pos, d.getPosition());
+      });
+    }
+  }
+  return pos;
+}
+
+ArrayAttr mlir::getReassociationIndicesAttribute(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
+  SmallVector<Attribute, 4> reassociationAttr =
+      llvm::to_vector<4>(llvm::map_range(
+          reassociation, [&](ReassociationIndices indices) -> Attribute {
+            return b.getI64ArrayAttr(indices).cast<Attribute>();
+          }));
+  return b.getArrayAttr(reassociationAttr);
+}
+
+SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+  SmallVector<ReassociationIndices, 2> reassociationIndices;
+  for (const auto &exprs : reassociationExprs) {
+    ReassociationIndices indices;
+    indices.reserve(exprs.size());
+    for (const auto &expr : exprs)
+      indices.push_back(expr.cast<AffineDimExpr>().getPosition());
+    reassociationIndices.push_back(indices);
+  }
+  return reassociationIndices;
+}
+
+SmallVector<AffineMap, 4>
+mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
+  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
+         "Expected symbol-less expressions");
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(reassociation.size());
+  for (const auto &exprs : reassociation) {
+    assert(!exprs.empty());
+    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
+  }
+  return maps;
+}
+
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                 int *invalidIndex) {
   if (reassociation.empty())
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -698,3 +698,105 @@
   return
 }

+// -----
+
+func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // Reshapes that expand a contiguous memref with some 1's.
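+  // Only the descriptor is rewritten: sizes become [1, 3, 4, 1, 5] and the
+  // source strides [20, 5, 1] become [60, 20, 5, 5, 1]; no data is copied.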
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+      : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(3 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(4 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(60 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(20 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+
+// -----
+
+func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return %0 : memref<3x4x5xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(3 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(4 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(20 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
+// -----
+
+func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+  return %0 : memref<f32>
+}
+// CHECK-LABEL: func @collapse_shape_fold_zero_dim
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+
+// -----
+
+func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+  return %0 : memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_shape_zero_dim
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -261,7 +261,7 @@
return %out : tensor<20xf32> } // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32> -// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]] +// CHECK: %[[RESHAPE:.*]] = memref.collapse_shape %[[MEMREF]] {{\[}}[0, 1]] // CHECK-SAME: : memref<4x5xf32> into memref<20xf32> // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32> // CHECK: return %[[TENSOR]] diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -81,19 +81,6 @@ // ----- -func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>) - -> memref { - %0 = linalg.collapse_shape %arg0 [[0, 1, 2]] - : memref<1x1x1xf32> into memref<1xf32> - %1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref - return %1 : memref -} -// CHECK-LABEL: collapsing_memref_reshapes_to_zero -// CHECK: linalg.collapse_shape %{{.*}} [] -// CHECK-SAME: memref<1x1x1xf32> into memref - -// ----- - func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor { %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] @@ -108,34 +95,6 @@ // ----- -func @collapsing_memref_reshapes(%arg0 : memref) -> memref -{ - %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] - : memref into memref - %1 = linalg.collapse_shape %0 [[0, 1], [2]] - : memref into memref - return %1 : memref -} -// CHECK-LABEL: collapsing_memref_reshapes -// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.collapse_shape - -// ----- - -func @expanding_memref_reshapes(%arg0 : memref) -> memref -{ - %0 = linalg.expand_shape %arg0 [[0, 1], [2]] - : memref into memref - %1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]] - : memref into memref - return %1 : memref -} -// CHECK-LABEL: expanding_memref_reshapes -// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.expand_shape - -// ----- - func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor) -> tensor<1x1x1xf32> { %0 = linalg.tensor_expand_shape %arg0 [] : tensor into tensor<1xf32> @@ -149,19 +108,6 @@ // ----- -func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref) - -> memref<1x1x1xf32> { - %0 = linalg.expand_shape %arg0 [] : memref into memref<1xf32> - %1 = linalg.expand_shape %0 [[0, 1, 2]] - : memref<1xf32> into memref<1x1x1xf32> - return %1 : memref<1x1x1xf32> -} -// CHECK-LABEL: expanding_memref_reshapes_to_zero -// CHECK: linalg.expand_shape %{{.*}} [] -// CHECK-SAME: memref into memref<1x1x1xf32> - -// ----- - func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> { %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] @@ -188,32 +134,6 @@ // ----- -func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> -{ - %0 = linalg.expand_shape %arg0 [[0, 1], [2]] - : memref<12x4xf32> into memref<3x4x4xf32> - %1 = linalg.collapse_shape %0 [[0, 1], [2]] - : memref<3x4x4xf32> into memref<12x4xf32> - return %1 : memref<12x4xf32> -} -// CHECK-LABEL: @fold_memref_reshape -// CHECK-NOT: linalg.{{.*}}_shape - -// ----- - -func @fold_memref_reshape_dynamic(%arg0 : memref) -> memref -{ - %0 = linalg.expand_shape %arg0 [[0, 1], [2]] - : memref into memref - %1 = linalg.collapse_shape %0 [[0, 1], [2]] - : memref into memref - return %1 : memref -} -// CHECK-LABEL: @fold_memref_reshape_dynamic -// CHECK-NOT: linalg.{{.*}}_shape - -// ----- - func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32> { %0 = 
linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]] diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -479,7 +479,7 @@ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @drop_one_trip_loops -// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1], [2]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] @@ -556,7 +556,7 @@ } // CHECK: #[[$MAP0:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @drop_all_loops -// CHECK: linalg.collapse_shape %{{.*}} [] +// CHECK: memref.collapse_shape %{{.*}} [] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK-SAME: iterator_types = [] @@ -617,7 +617,7 @@ // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @leading_dim_1_canonicalization -// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1]] +// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]] // CHECK-SAME: iterator_types = ["parallel"] @@ -638,8 +638,8 @@ func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32> { - %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32> - %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32> + %0 = memref.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32> + %1 = memref.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32> linalg.generic #trait ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>) outs(%shape : memref<5x5xf32>) { @@ -686,7 +686,7 @@ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @broadcast_scalar // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32> -// CHECK: %[[A:.*]] = linalg.collapse_shape %[[ARG0]] [] +// CHECK: %[[A:.*]] = memref.collapse_shape %[[ARG0]] [] // CHECK-SAME: memref<1x1xf32> into memref // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] @@ -706,16 +706,16 @@ ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32 } - %3 = linalg.collapse_shape %1 [[0, 1], [2]] + %3 = memref.collapse_shape %1 [[0, 1], [2]] : memref<1x2x5xf32> into memref<2x5xf32> return %3 : memref<2x5xf32> } // CHECK-LABEL: func @fold_unit_dim_memref_reshape_op // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32> -// CHECK: %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]] +// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ALLOC]] // CHECK: linalg.generic // CHECK-SAME: outs(%[[OUT:.*]] : -// CHECK: %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]] +// CHECK: %[[RESULT:.*]] = memref.collapse_shape %[[ALLOC]] // CHECK: return %[[RESULT]] // ----- @@ -740,8 +740,8 @@ // CHECK: func @fold_unit_dim_for_init_memref // CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32> -// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32> -// CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref +// CHECK: %[[INPUT_RESHAPE:.+]] = memref.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> 
into memref<1000xf32> +// CHECK: %[[INIT_RESHAPE:.+]] = memref.collapse_shape %[[INIT]] [] : memref<1xf32> into memref // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction"] diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -308,58 +308,6 @@ // ----- -func @reshape(%arg0: memref) { - // expected-error @+1 {{expected non-zero memref ranks}} - %0 = linalg.expand_shape %arg0 [[0]] : memref into memref -} - -// ----- - -func @collapse_to_higher_rank(%arg0: memref) { - // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} - %0 = linalg.collapse_shape %arg0 [[0]] : memref into memref<1xf32> -} - -// ----- - -func @expand_to_smaller_rank(%arg0: memref<1xf32>) { - // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} - %0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref -} - -// ----- - -func @reshape(%arg0: memref) { - // expected-error @+1 {{expected to collapse or expand dims}} - %0 = linalg.collapse_shape %arg0 [[0]] : memref into memref -} - -// ----- - -func @reshape(%arg0: memref) { - // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}} - %0 = linalg.collapse_shape %arg0 [[0, 1]] : - memref into memref -} - -// ----- - -func @reshape(%arg0: memref) { - // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}} - %0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] : - memref into memref -} - -// ----- - -func @reshape(%arg0: memref) { - // expected-error @+1 {{expected collapsed type to be 'memref', but got 'memref (d0 * s0 + d1)>>'}} - %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] : - memref into memref (d0 * s0 + d1)>> -} - -// ----- - func @pooling_rank_mismatch(%arg0: memref, %arg1: memref<2x3xf32>, %arg2: memref) { @@ -397,7 +345,6 @@ return } - // ----- func @init_tensor_err(%arg0 : index, %arg1 : index) @@ -438,16 +385,6 @@ // ----- -func @illegal_expanding_reshape_dynamic_memref - (%arg0: memref) -> memref -{ - // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} - %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]] - : memref into memref - return %0 : memref -} - -// ----- func @illegal_expanding_reshape_static_tensor (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> @@ -471,28 +408,6 @@ // ----- -func @illegal_expanding_reshape_static_memref - (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> -{ - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]] - : memref<2x3x20xf32> into memref<2x3x2x4x5xf32> - return %0 : memref<2x3x2x4x5xf32> -} - -// ----- - -func @illegal_collapsing_reshape_static_memref - (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> -{ - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] - : memref<2x3x2x4x5xf32> into memref<2x3x20xf32> - return %0 : memref<2x3x20xf32> -} - -// ----- - func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} @@ -533,46 +448,6 @@ // ----- -func @illegal_expanding_reshape_mixed_memref(%arg0 : 
memref) -> memref -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.expand_shape %arg0 [[0, 1], [2]] - : memref into memref - return %0 : memref -} - -// ----- - -func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref) -> memref -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.expand_shape %arg0 [[0], [1, 2]] - : memref into memref - return %0 : memref -} - -// ----- - -func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref) -> memref -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] - : memref into memref - return %0 : memref -} - -// ----- - -func @illegal_collapse_reshape_mixed_memref_2(%arg0 : memref) -> memref -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.collapse_shape %arg0 [[0], [1, 2]] - : memref into memref - return %0 : memref -} - -// ----- - func @pad_result_type(%arg0: tensor, %arg1: index, %arg2: i32) -> tensor { // expected-error @+1 {{specified type 'tensor' does not match the inferred type 'tensor}} %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] { @@ -824,6 +699,6 @@ linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) { ^bb0(%a: f32, %b: f32): linalg.yield %a : f32 - } + } return } diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -13,98 +13,3 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)> - -func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { - // Reshapes that expand a contiguous tensor with some 1's. 
- %0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] - : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> - return %0 : memref<1x3x4x1x5xf32> -} -// CHECK-LABEL: func @expand_shape_static -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(3 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(60 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(20 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> - -func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { - %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : - memref<1x3x4x1x5xf32> into memref<3x4x5xf32> - return %0 : memref<3x4x5xf32> -} -// CHECK-LABEL: func @collapse_shape_static -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x 
-// CHECK: llvm.mlir.constant(3 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(20 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-
-func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
-  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  return %0 : memref<f32>
-}
-// CHECK-LABEL: func @collapse_shape_fold_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-
-func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
-  return %0 : memref<1x1xf32>
-}
-// CHECK-LABEL: func @expand_shape_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -12,9 +12,7 @@
 // CHECK-DAG: #[[$permute_1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 // CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
 // CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
 // CHECK-DAG: #[[$strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)>
@@ -169,7 +167,6 @@

 // -----

-
 func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
   linalg.fill(%arg1, %arg0) : f32, memref<?xf32, offset: ?, strides: [1]>
   return
@@ -541,96 +538,6 @@

 // -----

-func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
-                     %arg2: tensor<3x?x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous buffer.
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
-    memref<3x4x5xf32> into memref<12x5xf32>
-  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
-    memref<12x5xf32> into memref<3x4x5xf32>
-  %1 = linalg.collapse_shape %arg0 [[0], [1, 2]] :
-    memref<3x4x5xf32> into memref<3x20xf32>
-  %r1 = linalg.expand_shape %1 [[0], [1, 2]] :
-    memref<3x20xf32> into memref<3x4x5xf32>
-  %2 = linalg.collapse_shape %arg0 [[0, 1, 2]] :
-    memref<3x4x5xf32> into memref<60xf32>
-  %r2 = linalg.expand_shape %2 [[0, 1, 2]] :
-    memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
-  %3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
-    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  %r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] :
-    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-  // Reshapes on tensors.
-  %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
-    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-  %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
-    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-  %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
-    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-  %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
-    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-  return
-}
-// CHECK-LABEL: func @reshape_static
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-//
-// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-
-// -----
-
-func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
-                      %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
-                      %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
-    memref<?x?x?xf32> into memref<?x?xf32>
-  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
-    memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.collapse_shape %arg1 [[0, 1], [2]] :
-    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
-    memref<?x?xf32, offset : 0, strides : [?, 1]>
-  %r1 = linalg.expand_shape %1 [[0, 1], [2]] :
-    memref<?x?xf32, offset : 0, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
-  %2 = linalg.collapse_shape %arg2 [[0, 1], [2]] :
-    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
-    memref<?x?xf32, offset : ?, strides : [?, 1]>
-  %r2 = linalg.expand_shape %2 [[0, 1], [2]] :
-    memref<?x?xf32, offset : ?, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @reshape
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
-
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
@@ -670,17 +577,6 @@

 // -----

-func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
-{
-  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  %1 = linalg.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
-  return %0, %1 : memref<f32>, memref<1x1xf32>
-}
-// CHECK-LABEL: func @memref_reshape_zero_dim
-// CHECK: linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-// CHECK: linalg.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
-
-// -----

 func @init_tensor(%arg0 : index, %arg1 : index) {
@@ -707,19 +603,6 @@

 // -----

-func @legal_collapsing_reshape_dynamic_memref
-  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
-{
-  %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
-    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
-  return %0 : memref<?x?x?xf32>
-}
-// CHECK: func @legal_collapsing_reshape_dynamic_memref
-// CHECK: linalg.collapse_shape
-// CHECK-SAME: [0], [1], [2, 3, 4]
-
-// -----
-
 func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
   %1 = linalg.fill(%arg2, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -395,3 +395,81 @@
   memref.store %0, %arg0[] : memref>
   return
 }
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+    -> memref<f32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
+      : memref<1x1x1xf32> into memref<1xf32>
+  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
+  return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK: memref.collapse_shape %{{.*}} []
+// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
+
+// -----
+
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+    -> memref<?x?xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
+      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<?x?x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+    -> memref<?x6x4x?x5xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
+      : memref<?x4x?xf32> into memref<?x6x4x?x5xf32>
+  return %1 : memref<?x6x4x?x5xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes
+// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+// CHECK-NOT: memref.expand_shape
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+    -> memref<1x1x1xf32> {
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2]]
+      : memref<1xf32> into memref<1x1x1xf32>
+  return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK: memref.expand_shape %{{.*}} []
+// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
+
+// -----
+
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<12x4xf32> into memref<3x4x4xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<3x4x4xf32> into memref<12x4xf32>
+  return %1 : memref<12x4xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape
+// CHECK-NOT: memref.{{.*}}_shape
+
+// -----
+
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<?x4x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-NOT: memref.{{.*}}_shape
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -231,3 +231,125 @@
   memref.copy %arg0, %arg1 : memref<2xf32> to memref<2xf16>
   return
 }
+
+// -----
+
+func @expand_shape(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected non-zero memref ranks}}
+  %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+}
+
+// -----
+
+func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected to collapse or expand dims}}
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
+}
+
+// -----
+
+func @collapse_shape_mismatch_indices_num(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
+  %0 = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [1, 2]] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+    memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
+}
+
+// -----
+
+func @expand_shape_illegal_dynamic_memref
+  (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
+  return %0 : memref<?x?x?x4x?xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_static_memref
+  (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
+  return %0 : memref<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_static_memref
+  (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> {
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
+  return %0 : memref<2x3x20xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
+    -> memref<?x4x5xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
+    -> memref<?x4x5xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+      : memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref(%arg0 : memref<?x4x5xf32>)
+    -> memref<?x?xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref<?x4x5xf32>)
+    -> memref<?x?xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]
+      : memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
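The mixed static/dynamic cases above all exercise one verifier rule: within each reassociation group, the product of the static sizes on the expanded side must equal the corresponding static dimension of the collapsed side, while a single dynamic dimension may pass through. For contrast, a minimal sketch of a legal mixed reshape (the function name and shapes here are illustrative, not taken from the patch's test suite):

```mlir
// Group [1, 2] covers expanded dims of sizes 4 and 5, so the collapsed
// dim 1 must be their static product 4 * 5 = 20; the dynamic dim 0 maps
// through unchanged via the singleton group [0].
func @collapse_shape_legal_mixed_memref(%arg0 : memref<?x4x5xf32>)
    -> memref<?x20xf32> {
  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]
      : memref<?x4x5xf32> into memref<?x20xf32>
  return %0 : memref<?x20xf32>
}
```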
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s

+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
+// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
+
 // CHECK-LABEL: test_buffer_cast
 func @test_buffer_cast(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
   %0 = memref.buffer_cast %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
@@ -95,3 +100,114 @@
   }
   return
 }
+
+func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
+                                   %arg1: tensor<3x4x5xf32>,
+                                   %arg2: tensor<3x?x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous buffer.
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+    memref<3x4x5xf32> into memref<12x5xf32>
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+    memref<12x5xf32> into memref<3x4x5xf32>
+  %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
+    memref<3x4x5xf32> into memref<3x20xf32>
+  %r1 = memref.expand_shape %1 [[0], [1, 2]] :
+    memref<3x20xf32> into memref<3x4x5xf32>
+  %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
+    memref<3x4x5xf32> into memref<60xf32>
+  %r2 = memref.expand_shape %2 [[0, 1, 2]] :
+    memref<60xf32> into memref<3x4x5xf32>
+  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
+  %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors.
+  %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+  return
+}
+// CHECK-LABEL: func @expand_collapse_shape_static
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+
+
+func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
+         %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
+         %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+    memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.collapse_shape %arg1 [[0, 1], [2]] :
+    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
+    memref<?x?xf32, offset : 0, strides : [?, 1]>
+  %r1 = memref.expand_shape %1 [[0, 1], [2]] :
+    memref<?x?xf32, offset : 0, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+  %2 = memref.collapse_shape %arg2 [[0, 1], [2]] :
+    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
+    memref<?x?xf32, offset : ?, strides : [?, 1]>
+  %r2 = memref.expand_shape %2 [[0, 1], [2]] :
+    memref<?x?xf32, offset : ?, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+  return
+}
+// CHECK-LABEL: func @expand_collapse_shape_dynamic
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
+
+func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
+    -> (memref<f32>, memref<1x1xf32>) {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
+  return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_collapse_shape_zero_dim
+// CHECK: memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+// CHECK: memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+func @collapse_shape_to_dynamic
+  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
+  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
+    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
+  return %0 : memref<?x?x?xf32>
+}
+// CHECK: func @collapse_shape_to_dynamic
+// CHECK: memref.collapse_shape
+// CHECK-SAME: [0], [1], [2, 3, 4]
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5854,6 +5854,7 @@
         ":LLVMDialect",
         ":LinalgOps",
         ":LinalgTransforms",
+        ":MemRefToLLVM",
        ":Pass",
         ":SCFDialect",
         ":SCFToStandard",
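With the `:MemRefToLLVM` dependency added above, descriptor lowering for the moved ops is provided by the MemRef-to-LLVM conversion instead of the Linalg path whose tests were deleted from llvm.mlir. A minimal sketch of exercising that lowering (the `-convert-memref-to-llvm` pass name is assumed from the MemRefToLLVM library this revision links in; the function below is illustrative, not part of the patch):

```mlir
// RUN: mlir-opt %s -convert-memref-to-llvm | FileCheck %s
// The lowering is pure descriptor rewriting, mirroring the deleted
// llvm.mlir checks: copy the pointers and offset, then insert constant
// sizes and strides; no data is moved.
func @expand(%arg0: memref<12xf32>) -> memref<3x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<12xf32> into memref<3x4xf32>
  return %0 : memref<3x4xf32>
}
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue
```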