diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1300,7 +1300,7 @@ static SmallVector getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { @@ -1308,42 +1308,84 @@ llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape[outDimIndex], - inStaticShape, inDesc, reassocation); + inStaticShape, inDesc, reassociation); })); } static SmallVector getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { DenseMap outDimToInDimMap = - getExpandedDimToCollapsedDimMap(reassocation); + getExpandedDimToCollapsedDimMap(reassociation); return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape, inDesc, inStaticShape, - reassocation, outDimToInDimMap); + reassociation, outDimToInDimMap); })); } static SmallVector getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { return outStaticShape.size() < inStaticShape.size() ? 
getAsValues(b, loc, llvmIndexType, getCollapsedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)) : getAsValues(b, loc, llvmIndexType, getExpandedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)); } +static void fillInStridesForExpandedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeExpandedLayoutMap for details on how the strides + // are calculated. + for (auto &en : llvm::enumerate(reassociation)) { + auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); + for (auto dstIndex : llvm::reverse(en.value())) { + dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); + Value size = dstDesc.size(b, loc, dstIndex); + currentStrideToExpand = + b.create(loc, size, currentStrideToExpand); + } + } +} + +static void fillInStridesForCollapsedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeCollapsedLayoutMap for details on how the strides + // are calculated.
+ auto srcShape = srcType.getShape(); + for (auto &en : llvm::enumerate(reassociation)) { + ArrayRef ref = llvm::makeArrayRef(en.value()); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + dstDesc.setStride(b, loc, en.index(), srcDesc.stride(b, loc, ref.back())); + } +} + +static void fillInDynamicStridesForMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefType dstType, + MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, + ArrayRef reassociation) { + if (srcType.getRank() > dstType.getRank()) + fillInStridesForCollapsedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); + else + fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); +} + // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. @@ -1360,15 +1402,6 @@ MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - // The condition on the layouts can be ignored when all shapes are static. 
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { - if (!srcType.getLayout().isIdentity() || - !dstType.getLayout().isIdentity()) { - return rewriter.notifyMatchFailure( - reshapeOp, "only empty layout map is supported"); - } - } - int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(dstType, strides, offset))) { @@ -1401,14 +1434,9 @@ for (auto &en : llvm::enumerate(strides)) dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); } else { - Value c1 = rewriter.create(loc, llvmIndexType, - rewriter.getIndexAttr(1)); - Value stride = c1; - for (auto dimIndex : - llvm::reverse(llvm::seq(0, dstShape.size()))) { - dstDesc.setStride(rewriter, loc, dimIndex, stride); - stride = rewriter.create(loc, dstShape[dimIndex], stride); - } + fillInDynamicStridesForMemDescriptor(rewriter, loc, srcType, dstType, + srcDesc, dstDesc, + reshapeOp.getReassociationIndices()); } rewriter.replaceOp(reshapeOp, {dstDesc}); return success(); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -802,11 +802,10 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: 
llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // ----- @@ -830,14 +829,17 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 - +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // ----- // CHECK-LABEL: func @rank_of_unranked