diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1300,7 +1300,7 @@ static SmallVector getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { @@ -1308,42 +1308,151 @@ llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape[outDimIndex], - inStaticShape, inDesc, reassocation); + inStaticShape, inDesc, reassociation); })); } static SmallVector getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { DenseMap outDimToInDimMap = - getExpandedDimToCollapsedDimMap(reassocation); + getExpandedDimToCollapsedDimMap(reassociation); return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape, inDesc, inStaticShape, - reassocation, outDimToInDimMap); + reassociation, outDimToInDimMap); })); } static SmallVector getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { return outStaticShape.size() < inStaticShape.size() ? 
getAsValues(b, loc, llvmIndexType, getCollapsedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)) : getAsValues(b, loc, llvmIndexType, getExpandedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)); } +static void fillInStridesForExpandedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeExpandedLayoutMap for details on how the strides + // are calculated. + for (auto &en : llvm::enumerate(reassociation)) { + auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); + for (auto dstIndex : llvm::reverse(en.value())) { + dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); + Value size = dstDesc.size(b, loc, dstIndex); + currentStrideToExpand = + b.create(loc, size, currentStrideToExpand); + } + } +} + +static void fillInStridesForCollapsedMemDescriptor( + ConversionPatternRewriter &rewriter, Location loc, + TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + + // See comments for computeCollapsedLayoutMap for details on how the strides + // are calculated. + auto srcShape = srcType.getShape(); + for (auto &en : llvm::enumerate(reassociation)) { + auto dstIndex = en.index(); + ArrayRef ref = llvm::makeArrayRef(en.value()); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { + dstDesc.setStride(rewriter, loc, dstIndex, + srcDesc.stride(rewriter, loc, ref.back())); + } else { + // Iterate over the source strides in reverse order. Skip over the dims + // whose size is 1. 
+ // + // +------------------------------------------------------+ + // | curEntry: | + // | %srcStride = strides[srcIndex] | + // | %neOne = cmp sizes[srcIndex],1 +--+ + // | cf.cond_br %neOne, continue(%srcStride), nextEntry | | + // +-------------------------+----------------------------+ | + // | | + // v | + // +-----------------------------+ | + // | nextEntry: | | + // | ... +---+ | + // +--------------+--------------+ | | + // | | | + // v | | + // +-----------------------------+ | | + // | nextEntry: | | | + // | ... | | | + // +--------------+--------------+ | +--------+ + // | | | + // v v v + // +--------------------------------------------------+ + // | continue(%newStride): | + // | %newMemRefDes = setStride(%newStride,dstIndex) | + // +--------------------------------------------------+ + OpBuilder::InsertionGuard guard(rewriter); + Block *initBlock = rewriter.getInsertionBlock(); + Block *continueBlock = + rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); + continueBlock->insertArgument( + unsigned(0), srcDesc.stride(rewriter, loc, 0).getType(), loc); + rewriter.setInsertionPointToStart(continueBlock); + dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); + + Block *curEntryBlock = initBlock; + Block *nextEntryBlock; + for (auto srcIndex : llvm::reverse(ref)) { + rewriter.setInsertionPointToEnd(curEntryBlock); + Value srcStride = srcDesc.stride(rewriter, loc, srcIndex); + if (srcIndex == ref.front()) { + rewriter.create(loc, srcStride, continueBlock); + break; + } + Value one = rewriter.create( + loc, typeConverter->convertType(rewriter.getI64Type()), + rewriter.getI32IntegerAttr(1)); + Value predNeOne = rewriter.create( + loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), + one); + { + OpBuilder::InsertionGuard guard(rewriter); + nextEntryBlock = rewriter.createBlock( + initBlock->getParent(), Region::iterator(continueBlock), {}); + } + rewriter.create(loc, predNeOne, continueBlock, + srcStride, 
nextEntryBlock, llvm::None); + curEntryBlock = nextEntryBlock; + } + } + } +} + +static void fillInDynamicStridesForMemDescriptor( + ConversionPatternRewriter &b, Location loc, TypeConverter *typeConverter, + MemRefType srcType, MemRefType dstType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + if (srcType.getRank() > dstType.getRank()) + fillInStridesForCollapsedMemDescriptor(b, loc, typeConverter, srcType, + srcDesc, dstDesc, reassociation); + else + fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); +} + // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. @@ -1357,18 +1466,10 @@ LogicalResult matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard guard(rewriter); MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - // The condition on the layouts can be ignored when all shapes are static. 
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { - if (!srcType.getLayout().isIdentity() || - !dstType.getLayout().isIdentity()) { - return rewriter.notifyMatchFailure( - reshapeOp, "only empty layout map is supported"); - } - } - int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(dstType, strides, offset))) { @@ -1401,14 +1502,9 @@ for (auto &en : llvm::enumerate(strides)) dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); } else { - Value c1 = rewriter.create(loc, llvmIndexType, - rewriter.getIndexAttr(1)); - Value stride = c1; - for (auto dimIndex : - llvm::reverse(llvm::seq(0, dstShape.size()))) { - dstDesc.setStride(rewriter, loc, dimIndex, stride); - stride = rewriter.create(loc, dstShape[dimIndex], stride); - } + fillInDynamicStridesForMemDescriptor(rewriter, loc, this->typeConverter, + srcType, dstType, srcDesc, dstDesc, + reshapeOp.getReassociationIndices()); } rewriter.replaceOp(reshapeOp, {dstDesc}); return success(); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -802,11 +802,10 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 
+// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // ----- @@ -830,14 +829,17 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 - +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // ----- // CHECK-LABEL: func @rank_of_unranked