diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1000,6 +1000,139 @@ } }; +/// Helper function to convert a vector of `OpFoldResult`s into a vector of +/// `Value`s, creating an `llvmIndexType` op for each attribute entry. +static SmallVector getAsValues(OpBuilder &b, Location loc, + Type &llvmIndexType, + ArrayRef valueOrAttrVec) { + return llvm::to_vector<4>( + llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { + if (auto attr = value.dyn_cast()) + return b.create(loc, llvmIndexType, attr); + return value.get(); + })); +} + +/// Compute a map that, for a given dimension of the expanded type, gives the +/// dimension in the collapsed type it maps to. Essentially it's the inverse of +/// the `reassociation` maps. +static DenseMap +getExpandedDimToCollapsedDimMap(ArrayRef reassociation) { + llvm::DenseMap expandedDimToCollapsedDim; + for (auto &en : enumerate(reassociation)) { + for (auto dim : en.value()) + expandedDimToCollapsedDim[dim] = en.index(); + } + return expandedDimToCollapsedDim; +} + +static OpFoldResult +getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType, + int64_t outDimIndex, ArrayRef outStaticShape, + MemRefDescriptor &inDesc, + ArrayRef inStaticShape, + ArrayRef reassocation, + DenseMap &outDimToInDimMap) { + int64_t outDimSize = outStaticShape[outDimIndex]; + if (!ShapedType::isDynamic(outDimSize)) + return b.getIndexAttr(outDimSize); + + // Calculate the product of all the out dim sizes in the same reassociation + // group, except the current dim. 
+ int64_t inDimIndex = outDimToInDimMap[outDimIndex]; + int64_t otherDimSizesMul = 1; + for (auto &otherDimIndex : reassocation[inDimIndex]) { + if (otherDimIndex == static_cast(outDimIndex)) + continue; + int64_t otherDimSize = outStaticShape[otherDimIndex]; + assert(!ShapedType::isDynamic(otherDimSize) && + "single dimension cannot be expanded into multiple dynamic " + "dimensions"); + otherDimSizesMul *= otherDimSize; + } + + // outDimSize = inDimSize / otherDimSizesMul + int64_t inDimSize = inStaticShape[inDimIndex]; + Value inDimSizeDynamic = + ShapedType::isDynamic(inDimSize) + ? inDesc.size(b, loc, inDimIndex) + : b.create(loc, llvmIndexType, + b.getIndexAttr(inDimSize)); + Value outDimSizeDynamic = b.create( + loc, inDimSizeDynamic, + b.create(loc, llvmIndexType, + b.getIndexAttr(otherDimSizesMul))); + return outDimSizeDynamic; +} + +static OpFoldResult getCollapsedOutputDimSize( + OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex, + int64_t outDimSize, ArrayRef inStaticShape, + MemRefDescriptor &inDesc, ArrayRef reassocation) { + if (!ShapedType::isDynamic(outDimSize)) + return b.getIndexAttr(outDimSize); + + Value c1 = b.create(loc, llvmIndexType, b.getIndexAttr(1)); + Value outDimSizeDynamic = c1; + for (auto &inDimIndex : reassocation[outDimIndex]) { + int64_t inDimSize = inStaticShape[inDimIndex]; + Value inDimSizeDynamic = + ShapedType::isDynamic(inDimSize) + ? 
inDesc.size(b, loc, inDimIndex) + : b.create(loc, llvmIndexType, + b.getIndexAttr(inDimSize)); + outDimSizeDynamic = + b.create(loc, outDimSizeDynamic, inDimSizeDynamic); + } + return outDimSizeDynamic; +} + +static SmallVector +getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, + ArrayRef reassocation, + ArrayRef inStaticShape, + MemRefDescriptor &inDesc, + ArrayRef outStaticShape) { + return llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { + return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, + outStaticShape[outDimIndex], + inStaticShape, inDesc, reassocation); + })); +} + +static SmallVector +getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, + ArrayRef reassocation, + ArrayRef inStaticShape, + MemRefDescriptor &inDesc, + ArrayRef outStaticShape) { + DenseMap outDimToInDimMap = + getExpandedDimToCollapsedDimMap(reassocation); + return llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { + return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, + outStaticShape, inDesc, inStaticShape, + reassocation, outDimToInDimMap); + })); +} + +static SmallVector +getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, + ArrayRef reassocation, + ArrayRef inStaticShape, MemRefDescriptor &inDesc, + ArrayRef outStaticShape) { + return outStaticShape.size() < inStaticShape.size() + ? getAsValues(b, loc, llvmIndexType, + getCollapsedOutputShape(b, loc, llvmIndexType, + reassocation, inStaticShape, + inDesc, outStaticShape)) + : getAsValues(b, loc, llvmIndexType, + getExpandedOutputShape(b, loc, llvmIndexType, + reassocation, inStaticShape, + inDesc, outStaticShape)); +} + // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. 
@@ -1014,35 +1147,59 @@ matchAndRewrite(ReshapeOp reshapeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType dstType = reshapeOp.getResultType(); - - if (!dstType.hasStaticShape()) - return failure(); + MemRefType srcType = reshapeOp.getSrcType(); + if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) { + return rewriter.notifyMatchFailure(reshapeOp, + "only empty layout map is supported"); + } int64_t offset; SmallVector strides; - auto res = getStridesAndOffset(dstType, strides, offset); - if (failed(res) || llvm::any_of(strides, [](int64_t val) { - return ShapedType::isDynamicStrideOrOffset(val); - })) - return failure(); + if (failed(getStridesAndOffset(dstType, strides, offset))) { + return rewriter.notifyMatchFailure( + reshapeOp, "failed to get stride and offset exprs"); + } ReshapeOpAdaptor adaptor(operands); - MemRefDescriptor baseDesc(adaptor.src()); + MemRefDescriptor srcDesc(adaptor.src()); Location loc = reshapeOp->getLoc(); - auto desc = - MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(), - this->typeConverter->convertType(dstType)); - desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc)); - desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc)); - desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc)); - for (auto en : llvm::enumerate(dstType.getShape())) - desc.setConstantSize(rewriter, loc, en.index(), en.value()); - for (auto en : llvm::enumerate(strides)) - desc.setConstantStride(rewriter, loc, en.index(), en.value()); - rewriter.replaceOp(reshapeOp, {desc}); + auto dstDesc = MemRefDescriptor::undef( + rewriter, loc, this->typeConverter->convertType(dstType)); + dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc)); + dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc)); + dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc)); + + ArrayRef srcStaticShape = srcType.getShape(); + ArrayRef 
dstStaticShape = dstType.getShape(); + Type llvmIndexType = + this->typeConverter->convertType(rewriter.getIndexType()); + SmallVector dstShape = getDynamicOutputShape( + rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(), + srcStaticShape, srcDesc, dstStaticShape); + for (auto en : llvm::enumerate(dstShape)) + dstDesc.setSize(rewriter, loc, en.index(), en.value()); + + auto isStaticStride = [](int64_t stride) { + return !ShapedType::isDynamicStrideOrOffset(stride); + }; + if (llvm::all_of(strides, isStaticStride)) { + for (auto en : llvm::enumerate(strides)) + dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); + } else { + Value c1 = rewriter.create(loc, llvmIndexType, + rewriter.getIndexAttr(1)); + Value stride = c1; + for (auto dimIndex : + llvm::reverse(llvm::seq(0, dstShape.size()))) { + dstDesc.setStride(rewriter, loc, dimIndex, stride); + stride = rewriter.create(loc, dstShape[dimIndex], stride); + } + } + rewriter.replaceOp(reshapeOp, {dstDesc}); return success(); } }; + /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. 
Updates to the descriptor to introduce the data ptr, offset, size diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -700,6 +700,34 @@ // ----- +func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { + %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : + memref<1x3x4x1x5xf32> into memref<3x4x5xf32> + return %0 : memref<3x4x5xf32> +} +// CHECK-LABEL: func @collapse_shape_static +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mlir.constant(3 : index) : i64 +// CHECK: llvm.mlir.constant(4 : index) : i64 +// CHECK: llvm.mlir.constant(5 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mlir.constant(20 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// 
CHECK: llvm.mlir.constant(5 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + +// ----- + func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { // Reshapes that expand a contiguous tensor with some 1's. %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] @@ -715,14 +743,14 @@ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(3 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(5 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 
x i64>, array<5 x i64>)> // CHECK: llvm.mlir.constant(60 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> @@ -735,33 +763,6 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// ----- - -func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { - %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : - memref<1x3x4x1x5xf32> into memref<3x4x5xf32> - return %0 : memref<3x4x5xf32> -} -// CHECK-LABEL: func @collapse_shape_static -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(3 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(4 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(20 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, 
i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(5 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // ----- @@ -793,10 +794,68 @@ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + +// ----- + +func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> { + %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: memref<1x2x?xf32> into memref<1x?xf32> + return %0 : memref<1x?xf32> +} +// CHECK-LABEL: func @collapse_shape_dynamic( +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x 
i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.mlir.constant(2 : index) : i64 +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 + +// ----- + +func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { + %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32> + return %0 : memref<1x2x?xf32> +} +// CHECK-LABEL: func @expand_shape_dynamic( +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, 
ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.constant(2 : index) : i64 +// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.mlir.constant(2 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \ +// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \ +// RUN: -finalizing-bufferize -convert-linalg-to-llvm \ +// RUN: -convert-memref-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// 
RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + + +func @main() { + %const = constant dense<[[[[-3.9058,0.9072],[-2.9470,-2.2055],[18.3946,8.2997]],[[3.4700,5.9006],[-17.2267,4.9777],[1.0450,-0.8201]]],[[[17.6996,-11.1763],[26.7775,-3.8823],[-4.2492,-5.8966]],[[2.1259,13.1794],[-10.7136,0.8428],[16.4233,9.4589]]]]> : tensor<2x2x3x2xf32> + %dynamic = tensor.cast %const: tensor<2x2x3x2xf32> to tensor<2x?x?x?xf32> + %collapsed = call @collapse_dynamic_shape(%dynamic) : (tensor<2x?x?x?xf32>) -> (tensor<2x?x?xf32>) + %unranked = tensor.cast %collapsed: tensor<2x?x?xf32> to tensor<*xf32> + call @print_memref_f32(%unranked) : (tensor<*xf32>) -> () + // CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}} + // CHECK-SAME: rank = 3 offset = 0 sizes = [2, 6, 2] strides = [12, 2, 1] data = + // CHECK-NEXT{LITERAL}: [[[-3.9058, 0.9072], + // CHECK-NEXT: [-2.947, -2.2055], + // CHECK-NEXT: [18.3946, 8.2997], + // CHECK-NEXT: [3.47, 5.9006], + // CHECK-NEXT: [-17.2267, 4.9777], + // CHECK-NEXT: [1.045, -0.8201]], + // CHECK-NEXT{LITERAL}: [[17.6996, -11.1763], + // CHECK-NEXT: [26.7775, -3.8823], + // CHECK-NEXT: [-4.2492, -5.8966], + // CHECK-NEXT: [2.1259, 13.1794], + // CHECK-NEXT: [-10.7136, 0.8428], + // CHECK-NEXT: [16.4233, 9.4589]]] + return +} + +func private @print_memref_f32(%ptr : tensor<*xf32>) + +func @collapse_dynamic_shape(%arg0 : tensor<2x?x?x?xf32>) -> tensor<2x?x?xf32> { + %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]: tensor<2x?x?x?xf32> into tensor<2x?x?xf32> + return %0 : tensor<2x?x?xf32> +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \ +// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \ +// RUN: 
-finalizing-bufferize -convert-linalg-to-llvm \ +// RUN: -convert-memref-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + + +func @main() { + %const = constant dense<[[[-3.9058,0.9072],[-2.9470,-2.2055],[18.3946,8.2997],[3.4700,5.9006],[-17.2267,4.9777],[1.0450,-0.8201]],[[17.6996,-11.1763],[26.7775,-3.8823],[-4.2492,-5.8966],[2.1259,13.1794],[-10.7136,0.8428],[16.4233,9.4589]]]> : tensor<2x6x2xf32> + %dynamic = tensor.cast %const: tensor<2x6x2xf32> to tensor<2x?x?xf32> + %expanded = call @expand_dynamic_shape(%dynamic) : (tensor<2x?x?xf32>) -> (tensor<2x2x?x1x?xf32>) + %unranked = tensor.cast %expanded: tensor<2x2x?x1x?xf32> to tensor<*xf32> + call @print_memref_f32(%unranked) : (tensor<*xf32>) -> () + + // CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}} + // CHECK-SAME: rank = 5 offset = 0 sizes = [2, 2, 3, 1, 2] strides = [12, 6, 2, 2, 1] data = + // CHECK-NEXT{LITERAL}: [[[[[-3.9058, 0.9072]], + // CHECK-NEXT{LITERAL}: [[-2.947, -2.2055]], + // CHECK-NEXT{LITERAL}: [[18.3946, 8.2997]]], + // CHECK-NEXT{LITERAL}: [[[3.47, 5.9006]], + // CHECK-NEXT{LITERAL}: [[-17.2267, 4.9777]], + // CHECK-NEXT{LITERAL}: [[1.045, -0.8201]]]], + // CHECK-NEXT{LITERAL}: [[[[17.6996, -11.1763]], + // CHECK-NEXT{LITERAL}: [[26.7775, -3.8823]], + // CHECK-NEXT{LITERAL}: [[-4.2492, -5.8966]]], + // CHECK-NEXT{LITERAL}: [[[2.1259, 13.1794]], + // CHECK-NEXT{LITERAL}: [[-10.7136, 0.8428]], + // CHECK-NEXT{LITERAL}: [[16.4233, 9.4589]]]]] + return +} + +func private @print_memref_f32(%ptr : tensor<*xf32>) + +func @expand_dynamic_shape(%arg0 : tensor<2x?x?xf32>) -> tensor<2x2x?x1x?xf32> { + %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2, 3], [4]]: tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32> + return %0 : tensor<2x2x?x1x?xf32> +}