diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1300,7 +1300,7 @@
 
 static SmallVector
 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                        ArrayRef reassocation,
+                        ArrayRef reassociation,
                         ArrayRef inStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef outStaticShape) {
@@ -1308,42 +1308,98 @@
       llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) {
         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                          outStaticShape[outDimIndex],
-                                         inStaticShape, inDesc, reassocation);
+                                         inStaticShape, inDesc, reassociation);
       }));
 }
 
 static SmallVector
 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                       ArrayRef reassocation,
+                       ArrayRef reassociation,
                        ArrayRef inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef outStaticShape) {
   DenseMap outDimToInDimMap =
-      getExpandedDimToCollapsedDimMap(reassocation);
+      getExpandedDimToCollapsedDimMap(reassociation);
   return llvm::to_vector<4>(llvm::map_range(
       llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) {
         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape, inDesc, inStaticShape,
-                                        reassocation, outDimToInDimMap);
+                                        reassociation, outDimToInDimMap);
       }));
 }
 
 static SmallVector
 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                      ArrayRef reassocation,
+                      ArrayRef reassociation,
                       ArrayRef inStaticShape, MemRefDescriptor &inDesc,
                       ArrayRef outStaticShape) {
   return outStaticShape.size() < inStaticShape.size()
              ? getAsValues(b, loc, llvmIndexType,
                            getCollapsedOutputShape(b, loc, llvmIndexType,
-                                                   reassocation, inStaticShape,
+                                                   reassociation, inStaticShape,
                                                    inDesc, outStaticShape))
              : getAsValues(b, loc, llvmIndexType,
                            getExpandedOutputShape(b, loc, llvmIndexType,
-                                                  reassocation, inStaticShape,
+                                                  reassociation, inStaticShape,
                                                   inDesc, outStaticShape));
 }
 
+// Fills in the strides of the expanded descriptor `dstDesc`. Within each
+// reassociation group, the innermost result dimension takes the stride of the
+// corresponding source dimension, and every outer dimension of the group takes
+// the stride of its inner neighbour multiplied by that neighbour's size. The
+// sizes of `dstDesc` are read, so they must be populated beforehand.
+static void fillInStridesForExpandedMemDescriptor(
+    OpBuilder &b, Location loc, MemRefDescriptor &srcDesc,
+    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
+  for (auto &en : llvm::enumerate(reassociation)) {
+    // The i-th reassociation group expands the i-th source dimension.
+    Value currentStrideToExpand = srcDesc.stride(b, loc, en.index());
+    for (int64_t dstIndex : llvm::reverse(en.value())) {
+      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
+      Value size = dstDesc.size(b, loc, dstIndex);
+      currentStrideToExpand =
+          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
+    }
+  }
+}
+
+// Fills in the strides of the collapsed descriptor `dstDesc`. Each result
+// dimension takes the stride of the innermost source dimension of its
+// reassociation group, skipping trailing dimensions whose static size is 1
+// since their strides do not affect addressing. Only statically-known unit
+// dimensions are skipped, and at least one dimension of every group is kept.
+static void fillInStridesForCollapsedMemDescriptor(
+    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
+    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  for (auto &en : llvm::enumerate(reassociation)) {
+    ArrayRef<int64_t> group = llvm::makeArrayRef(en.value());
+    // Never drop the last remaining dimension: a group made only of unit
+    // dimensions (e.g. collapsing 1x1 into 1) would otherwise become empty
+    // and `group.back()` would be invalid.
+    while (group.size() > 1 && srcShape[group.back()] == 1)
+      group = group.drop_back();
+    dstDesc.setStride(b, loc, en.index(),
+                      srcDesc.stride(b, loc, group.back()));
+  }
+}
+
+// Emits IR that fills in the strides of `dstDesc` from those of `srcDesc`,
+// dispatching on whether the reshape collapses dimensions (the source has the
+// higher rank) or expands them.
+static void fillInDynamicStridesForMemDescriptor(
+    OpBuilder &b, Location loc, MemRefType srcType, MemRefType dstType,
+    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
+    ArrayRef<ReassociationIndices> reassociation) {
+  if (srcType.getRank() > dstType.getRank())
+    fillInStridesForCollapsedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
+                                           reassociation);
+  else
+    fillInStridesForExpandedMemDescriptor(b, loc, srcDesc, dstDesc,
+                                          reassociation);
+}
+
 // ReshapeOp creates a new view descriptor of the proper rank.
 // For now, the only conversion supported is for target MemRef with static sizes
 // and strides.
@@ -1361,13 +1417,4 @@
     MemRefType srcType = reshapeOp.getSrcType();
 
-    // The condition on the layouts can be ignored when all shapes are static.
-    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
-      if (!srcType.getLayout().isIdentity() ||
-          !dstType.getLayout().isIdentity()) {
-        return rewriter.notifyMatchFailure(
-            reshapeOp, "only empty layout map is supported");
-      }
-    }
-
     int64_t offset;
     SmallVector strides;
@@ -1401,14 +1448,9 @@
       for (auto &en : llvm::enumerate(strides))
         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
     } else {
-      Value c1 = rewriter.create(loc, llvmIndexType,
-                                 rewriter.getIndexAttr(1));
-      Value stride = c1;
-      for (auto dimIndex :
-           llvm::reverse(llvm::seq(0, dstShape.size()))) {
-        dstDesc.setStride(rewriter, loc, dimIndex, stride);
-        stride = rewriter.create(loc, dstShape[dimIndex], stride);
-      }
+      fillInDynamicStridesForMemDescriptor(rewriter, loc, srcType, dstType,
+                                           srcDesc, dstDesc,
+                                           reshapeOp.getReassociationIndices());
     }
     rewriter.replaceOp(reshapeOp, {dstDesc});
     return success();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1826,8 +1826,15 @@
   // reassociation.
   SmallVector resultStrides;
   resultStrides.reserve(reassociation.size());
-  for (ReassociationIndices reassoc : reassociation)
-    resultStrides.push_back(srcStrides[reassoc.back()]);
+  for (const ReassociationIndices &reassoc : reassociation) {
+    // Skip trailing dimensions whose static size is 1, as their strides do not
+    // affect addressing, but always keep at least one dimension so that a
+    // group made only of unit dimensions does not become empty.
+    ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
+    while (ref.size() > 1 && srcShape[ref.back()] == 1)
+      ref = ref.drop_back();
+    resultStrides.push_back(srcStrides[ref.back()]);
+  }
 
   // Validate that each reassociation group is contiguous.
   unsigned resultStrideIndex = resultStrides.size() - 1;