diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1374,256 +1374,8 @@
   }
 };
 
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s.
-static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
-                                      Type &llvmIndexType,
-                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
-  return llvm::to_vector<4>(
-      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
-        if (auto attr = value.dyn_cast<Attribute>())
-          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
-        return value.get<Value>();
-      }));
-}
-
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (auto &en : enumerate(reassociation)) {
-    for (auto dim : en.value())
-      expandedDimToCollapsedDim[dim] = en.index();
-  }
-  return expandedDimToCollapsedDim;
-}
-
-static OpFoldResult
-getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
-                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
-                         MemRefDescriptor &inDesc,
-                         ArrayRef<int64_t> inStaticShape,
-                         ArrayRef<ReassociationIndices> reassocation,
-                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
-  int64_t outDimSize = outStaticShape[outDimIndex];
-  if (!ShapedType::isDynamic(outDimSize))
-    return b.getIndexAttr(outDimSize);
-
-  // Calculate the multiplication of all the out dim sizes except the
-  // current dim.
-  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
-  int64_t otherDimSizesMul = 1;
-  for (auto otherDimIndex : reassocation[inDimIndex]) {
-    if (otherDimIndex == static_cast<int64_t>(outDimIndex))
-      continue;
-    int64_t otherDimSize = outStaticShape[otherDimIndex];
-    assert(!ShapedType::isDynamic(otherDimSize) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    otherDimSizesMul *= otherDimSize;
-  }
-
-  // outDimSize = inDimSize / otherOutDimSizesMul
-  int64_t inDimSize = inStaticShape[inDimIndex];
-  Value inDimSizeDynamic =
-      ShapedType::isDynamic(inDimSize)
-          ? inDesc.size(b, loc, inDimIndex)
-          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                       b.getIndexAttr(inDimSize));
-  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
-      loc, inDimSizeDynamic,
-      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                 b.getIndexAttr(otherDimSizesMul)));
-  return outDimSizeDynamic;
-}
-
-static OpFoldResult getCollapsedOutputDimSize(
-    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
-    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
-    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
-  if (!ShapedType::isDynamic(outDimSize))
-    return b.getIndexAttr(outDimSize);
-
-  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
-  Value outDimSizeDynamic = c1;
-  for (auto inDimIndex : reassocation[outDimIndex]) {
-    int64_t inDimSize = inStaticShape[inDimIndex];
-    Value inDimSizeDynamic =
-        ShapedType::isDynamic(inDimSize)
-            ? inDesc.size(b, loc, inDimIndex)
-            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                         b.getIndexAttr(inDimSize));
-    outDimSizeDynamic =
-        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
-  }
-  return outDimSizeDynamic;
-}
-
-static SmallVector<OpFoldResult>
-getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                        ArrayRef<ReassociationIndices> reassociation,
-                        ArrayRef<int64_t> inStaticShape,
-                        MemRefDescriptor &inDesc,
-                        ArrayRef<int64_t> outStaticShape) {
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
-        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
-                                         outStaticShape[outDimIndex],
-                                         inStaticShape, inDesc, reassociation);
-      }));
-}
-
-static SmallVector<OpFoldResult>
-getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                       ArrayRef<ReassociationIndices> reassociation,
-                       ArrayRef<int64_t> inStaticShape,
-                       MemRefDescriptor &inDesc,
-                       ArrayRef<int64_t> outStaticShape) {
-  DenseMap<int64_t, int64_t> outDimToInDimMap =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
-        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
-                                        outStaticShape, inDesc, inStaticShape,
-                                        reassociation, outDimToInDimMap);
-      }));
-}
-
-static SmallVector<Value>
-getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                      ArrayRef<ReassociationIndices> reassociation,
-                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
-                      ArrayRef<int64_t> outStaticShape) {
-  return outStaticShape.size() < inStaticShape.size()
-             ? getAsValues(b, loc, llvmIndexType,
-                           getCollapsedOutputShape(b, loc, llvmIndexType,
-                                                   reassociation, inStaticShape,
-                                                   inDesc, outStaticShape))
-             : getAsValues(b, loc, llvmIndexType,
-                           getExpandedOutputShape(b, loc, llvmIndexType,
-                                                  reassociation, inStaticShape,
-                                                  inDesc, outStaticShape));
-}
-
-static void fillInStridesForExpandedMemDescriptor(
-    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
-    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
-  // See comments for computeExpandedLayoutMap for details on how the strides
-  // are calculated.
-  for (auto &en : llvm::enumerate(reassociation)) {
-    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
-    for (auto dstIndex : llvm::reverse(en.value())) {
-      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
-      Value size = dstDesc.size(b, loc, dstIndex);
-      currentStrideToExpand =
-          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
-    }
-  }
-}
-
-static void fillInStridesForCollapsedMemDescriptor(
-    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
-    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
-    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
-  auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
-  // See comments for computeCollapsedLayoutMap for details on how the strides
-  // are calculated.
-  auto srcShape = srcType.getShape();
-  for (auto &en : llvm::enumerate(reassociation)) {
-    rewriter.setInsertionPoint(op);
-    auto dstIndex = en.index();
-    ArrayRef<int64_t> ref = llvm::ArrayRef(en.value());
-    while (srcShape[ref.back()] == 1 && ref.size() > 1)
-      ref = ref.drop_back();
-    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
-      dstDesc.setStride(rewriter, loc, dstIndex,
-                        srcDesc.stride(rewriter, loc, ref.back()));
-    } else {
-      // Iterate over the source strides in reverse order. Skip over the
-      // dimensions whose size is 1.
-      // TODO: we should take the minimum stride in the reassociation group
-      // instead of just the first where the dimension is not 1.
-      //
-      //  +------------------------------------------------------+
-      //  | curEntry:                                             |
-      //  |   %srcStride = strides[srcIndex]                      |
-      //  |   %neOne = cmp sizes[srcIndex],1                      +--+
-      //  |   cf.cond_br %neOne, continue(%srcStride), nextEntry  |  |
-      //  +-------------------------+----------------------------+  |
-      //                            |                                |
-      //                            v                                |
-      //              +-----------------------------+               |
-      //              | nextEntry:                  |               |
-      //              |   ...                       +---+           |
-      //              +--------------+--------------+   |           |
-      //                             |                  |           |
-      //                             v                  |           |
-      //              +-----------------------------+   |           |
-      //              | nextEntry:                  |   |           |
-      //              |   ...                       |   |           |
-      //              +--------------+--------------+   |  +--------+
-      //                             |                  |  |
-      //                             v                  v  v
-      //            +--------------------------------------------------+
-      //            | continue(%newStride):                            |
-      //            |   %newMemRefDes = setStride(%newStride,dstIndex) |
-      //            +--------------------------------------------------+
-      OpBuilder::InsertionGuard guard(rewriter);
-      Block *initBlock = rewriter.getInsertionBlock();
-      Block *continueBlock =
-          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
-      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
-      rewriter.setInsertionPointToStart(continueBlock);
-      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
-
-      Block *curEntryBlock = initBlock;
-      Block *nextEntryBlock;
-      for (auto srcIndex : llvm::reverse(ref)) {
-        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
-          continue;
-        rewriter.setInsertionPointToEnd(curEntryBlock);
-        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
-        if (srcIndex == ref.front()) {
-          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
-          break;
-        }
-        Value one = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                                      rewriter.getIndexAttr(1));
-        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
-            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
-            one);
-        {
-          OpBuilder::InsertionGuard guard(rewriter);
-          nextEntryBlock = rewriter.createBlock(
-              initBlock->getParent(), Region::iterator(continueBlock), {});
-        }
-        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
-                                        srcStride, nextEntryBlock,
-                                        std::nullopt);
-        curEntryBlock = nextEntryBlock;
-      }
-    }
-  }
-}
-
-static void fillInDynamicStridesForMemDescriptor(
-    ConversionPatternRewriter &b, Location loc, Operation *op,
-    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
-    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
-    ArrayRef<ReassociationIndices> reassociation) {
-  if (srcType.getRank() > dstType.getRank())
-    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
-                                           srcDesc, dstDesc, reassociation);
-  else
-    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
-                                          reassociation);
-}
-
-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
+/// ReassociatingReshapeOp must be expanded before we reach this stage.
+/// Report a match failure otherwise.
 template <typename ReshapeOp>
 class ReassociatingReshapeOpConversion
     : public ConvertOpToLLVMPattern<ReshapeOp> {
@@ -1634,56 +1386,9 @@
   LogicalResult
   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType dstType = reshapeOp.getResultType();
-    MemRefType srcType = reshapeOp.getSrcType();
-
-    int64_t offset;
-    SmallVector<int64_t> strides;
-    if (failed(getStridesAndOffset(dstType, strides, offset))) {
-      return rewriter.notifyMatchFailure(
-          reshapeOp, "failed to get stride and offset exprs");
-    }
-
-    MemRefDescriptor srcDesc(adaptor.getSrc());
-    Location loc = reshapeOp->getLoc();
-    auto dstDesc = MemRefDescriptor::undef(
-        rewriter, loc, this->typeConverter->convertType(dstType));
-    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
-    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
-    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
-
-    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
-    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
-    Type llvmIndexType =
-        this->typeConverter->convertType(rewriter.getIndexType());
-    SmallVector<Value> dstShape = getDynamicOutputShape(
-        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
-        srcStaticShape, srcDesc, dstStaticShape);
-    for (auto &en : llvm::enumerate(dstShape))
-      dstDesc.setSize(rewriter, loc, en.index(), en.value());
-
-    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
-      for (auto &en : llvm::enumerate(strides))
-        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
-    } else if (srcType.getLayout().isIdentity() &&
-               dstType.getLayout().isIdentity()) {
-      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                                   rewriter.getIndexAttr(1));
-      Value stride = c1;
-      for (auto dimIndex :
-           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
-        dstDesc.setStride(rewriter, loc, dimIndex, stride);
-        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
-      }
-    } else {
-      // There could be mixed static/dynamic strides. For simplicity, we
-      // recompute all strides if there is at least one dynamic stride.
- fillInDynamicStridesForMemDescriptor( - rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType, - srcDesc, dstDesc, reshapeOp.getReassociationIndices()); - } - rewriter.replaceOp(reshapeOp, {dstDesc}); - return success(); + return rewriter.notifyMatchFailure( + reshapeOp, + "reassociation operations should have been expanded beforehand"); } }; diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -373,3 +373,298 @@ %0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>> return %0 : memref<7xf32, strided<[-1], offset: 6>> } + +// ----- + +func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { + %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : + memref<1x3x4x1x5xf32> into memref<3x4x5xf32> + return %0 : memref<3x4x5xf32> +} +// CHECK-LABEL: func @collapse_shape_static +// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C20]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C4]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64 +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C5]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[C5]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<3x4x5xf32> +// CHECK: return %[[RES]] : memref<3x4x5xf32> +// CHECK: } + +// ----- + +func.func 
@collapse_shape_dynamic_with_non_identity_layout( + %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> + memref<4x?xf32, strided<[?, ?], offset: ?>> { + %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: + memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into + memref<4x?xf32, strided<[?, ?], offset: ?>> + return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>> +} +// CHECK-LABEL: func.func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK-SAME: %[[ARG:.*]]: memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> memref<4x?xf32, strided<[?, ?], offset: ?>> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index +// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64 +// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64 +// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index +// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C4]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: } +// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK32: llvm.mlir.constant(1 : index) : i32 +// CHECK32: llvm.mlir.constant(4 : index) : i32 
+// CHECK32: llvm.mlir.constant(1 : index) : i32 + +// ----- + + +func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { + // Reshapes that expand a contiguous tensor with some 1's. + %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] + : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> + return %0 : memref<1x3x4x1x5xf32> +} +// CHECK-LABEL: func @expand_shape_static +// CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C60:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C60]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C3]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C20]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[C4]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64 +// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C5]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC9:.*]] = llvm.insertvalue %[[C1]], %[[DESC8]][3, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC10:.*]] = llvm.insertvalue %[[C5]], %[[DESC9]][4, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[C5]], %[[DESC10]][3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[C1]], %[[DESC11]][4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC12]] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> to memref<1x3x4x1x5xf32> +// CHECK: return %[[RES]] : memref<1x3x4x1x5xf32> +// CHECK: } + +// ----- + +func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref { + %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref + return %0 : memref +} +// CHECK-LABEL: func.func 
@collapse_shape_fold_zero_dim( +// CHECK-SAME: %[[ARG:.*]]: memref<1x1xf32>) -> memref { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC2]] : !llvm.struct<(ptr, ptr, i64)> to memref +// CHECK: return %[[RES]] : memref +// CHECK: } + +// ----- + +func.func @expand_shape_zero_dim(%arg0 : memref) -> memref<1x1xf32> { + %0 = memref.expand_shape %arg0 [] : memref into memref<1x1xf32> + return %0 : memref<1x1xf32> +} + +// CHECK-LABEL: func.func @expand_shape_zero_dim( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref<1x1xf32> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref to !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x1xf32> +// CHECK: return %[[RES]] : memref<1x1xf32> +// CHECK: } + +// ----- + +func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> { + %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: memref<1x2x?xf32> into memref<1x?xf32> + return %0 : memref<1x?xf32> +} + +// CHECK-LABEL: func.func @collapse_shape_dynamic( +// CHECK-SAME: %[[ARG:.*]]: memref<1x2x?xf32>) -> memref<1x?xf32> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x2x?xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: 
%[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64 +// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index +// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64 +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64 +// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64 +// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index +// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32> +// CHECK: return %[[RES]] : memref<1x?xf32> +// CHECK: } + +// ----- + +func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { + %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32> + return %0 : memref<1x2x?xf32> +} + +// CHECK-LABEL: func.func @expand_shape_dynamic( +// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK: %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64 +// CHECK: %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64 +// CHECK: %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64 +// CHECK: %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64 +// CHECK: %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64 +// CHECK: %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64 +// CHECK: %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64 +// CHECK: %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index +// CHECK: %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// In this example stride1 and size2 are the same. +// Hence with CSE, we get the same SSA value. 
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32> +// CHECK: return %[[RES]] : memref<1x2x?xf32> +// CHECK: } + +// ----- + +func.func @expand_shape_dynamic_with_non_identity_layout( + %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) -> + memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { + %0 = memref.expand_shape %arg0 [[0], [1, 2]]: + memref<1x?xf32, strided<[?, ?], offset: ?>> into + memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> + return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> +} +// CHECK-LABEL: func.func @expand_shape_dynamic_with_non_identity_layout( +// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { +// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK: %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64 +// CHECK: %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64 +// CHECK: %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64 +// CHECK: %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64 +// CHECK: %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64 +// CHECK: %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64 +// CHECK: %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64 +// CHECK: %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index +// CHECK: %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64 +// CHECK: %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]] +// CHECK: %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index +// CHECK: %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64 +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// 
CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> +// CHECK: return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> +// CHECK: } + +// ----- + +// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout +func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> { +// CHECK-NOT: memref.collapse_shape + %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>> + return %1 : memref<64xf32, strided<[1], offset: ?>> +} diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -290,246 +290,27 @@ // ----- -func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { - %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : - memref<1x3x4x1x5xf32> into memref<3x4x5xf32> - return %0 : memref<3x4x5xf32> -} -// CHECK-LABEL: func @collapse_shape_static -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(3 : index) : i64 -// CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : 
!llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(20 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - -// ----- - -func.func @collapse_shape_dynamic_with_non_identity_layout( - %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> - memref<4x?xf32, strided<[?, ?], offset: ?>> { - %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: - memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into - memref<4x?xf32, strided<[?, ?], offset: ?>> - return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>> -} -// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1 -// CHECK: ^bb1: -// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.br ^bb2(%{{.*}} : i64) -// CHECK: ^bb2(%[[STRIDE:.*]]: i64): -// CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( -// CHECK32: llvm.mlir.constant(1 : index) : i32 -// CHECK32: 
llvm.mlir.constant(4 : index) : i32 -// CHECK32: llvm.mlir.constant(1 : index) : i32 - -// ----- - +// Expand shapes need to be expanded outside of the memref-to-llvm pass. +// CHECK-LABEL: func @expand_shape_static( +// CHECK-SAME: %[[ARG:.*]]: memref<{{.*}}>) func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { + // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] // Reshapes that expand a contiguous tensor with some 1's. %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> return %0 : memref<1x3x4x1x5xf32> } -// CHECK-LABEL: func @expand_shape_static -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(3 : index) : i64 -// CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(60 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(20 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> - - -// ----- - -func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref { - %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref - return %0 : memref -} -// CHECK-LABEL: func @collapse_shape_fold_zero_dim -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, 
ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> // ----- -func.func @expand_shape_zero_dim(%arg0 : memref) -> memref<1x1xf32> { - %0 = memref.expand_shape %arg0 [] : memref into memref<1x1xf32> - return %0 : memref<1x1xf32> -} -// CHECK-LABEL: func @expand_shape_zero_dim -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - -// ----- - -func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> { - %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: memref<1x2x?xf32> into memref<1x?xf32> - return %0 : memref<1x?xf32> -} -// CHECK-LABEL: func @collapse_shape_dynamic( -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(2 : index) : i64 -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 - -// ----- - -func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { - %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32> - return %0 : memref<1x2x?xf32> -} -// CHECK-LABEL: func @expand_shape_dynamic( -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(2 : index) : i64 -// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(2 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 - -// ----- - -func.func @expand_shape_dynamic_with_non_identity_layout( - %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) -> - memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { - %0 = memref.expand_shape %arg0 [[0], [1, 2]]: - memref<1x?xf32, strided<[?, ?], offset: ?>> into - memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> - return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> +// Collapse shapes need to be expanded outside of the memref-to-llvm pass. 
+// CHECK-LABEL: func @collapse_shape_static +// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { +func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { + %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : + memref<1x3x4x1x5xf32> into memref<3x4x5xf32> + return %0 : memref<3x4x5xf32> } -// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout( -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(2 : index) : i64 -// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.mlir.constant(2 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // ----- @@ -579,15 +360,6 @@ // ----- -// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout -func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> { -// CHECK-NOT: memref.collapse_shape - %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>> - return %1 : memref<64xf32, strided<[1], offset: ?>> -} - -// ----- - // CHECK-LABEL: func @generic_atomic_rmw func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) { %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> { diff --git 
a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -linalg-bufferize \ // RUN: -arith-bufferize -tensor-bufferize -func-bufferize \ // RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \ +// RUN: -expand-strided-metadata -lower-affine \ // RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -linalg-bufferize \ // RUN: -arith-bufferize -tensor-bufferize -func-bufferize \ // RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \ +// RUN: -expand-strided-metadata -lower-affine \ // RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
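
As the updated RUN lines show, `memref.expand_shape` and `memref.collapse_shape` are no longer lowered by `-convert-memref-to-llvm` itself: they must first be decomposed by `-expand-strided-metadata`, with `-lower-affine` taking care of the affine ops that decomposition may introduce, and only then does the MemRef-to-LLVM conversion succeed (otherwise the pattern above reports a match failure). A minimal sketch of the required pipeline, assuming a hypothetical `input.mlir` containing such reshape ops:

  # Decompose the reassociating reshapes, lower the resulting affine ops,
  # then convert to the LLVM dialect (input.mlir is a placeholder name).
  mlir-opt input.mlir \
    -expand-strided-metadata -lower-affine \
    -convert-memref-to-llvm -convert-func-to-llvm \
    -reconcile-unrealized-casts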