diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -706,12 +706,60 @@
   }
 };
 
+/// Pattern to lower a `memref.copy` to LLVM.
+///
+/// For memrefs with identity layouts, the copy is lowered to the LLVM
+/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
+/// to the generic `MemrefCopyFn`.
 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
 
+  /// Lowers a copy between two ranked, identity-layout memrefs to a single
+  /// `memcpy` of (product of the source sizes) * (element size in bytes).
   LogicalResult
-  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    // The caller only dispatches here for ranked memrefs, so a checked cast
+    // (rather than a possibly-null dyn_cast) is the right tool.
+    auto srcType = op.source().getType().cast<MemRefType>();
+
+    MemRefDescriptor srcDesc(adaptor.source());
+
+    // Compute number of elements. `numElements` stays null for rank-0 memrefs,
+    // which hold exactly one element.
+    Value numElements;
+    for (int pos = 0; pos < srcType.getRank(); ++pos) {
+      auto size = srcDesc.size(rewriter, loc, pos);
+      numElements = numElements
+                        ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
+                        : size;
+    }
+    // Get element size.
+    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+    // Compute total. For rank-0 memrefs this is just the element size;
+    // multiplying by the null `numElements` would build an invalid op.
+    Value totalSize = sizeInBytes;
+    if (numElements)
+      totalSize = rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+
+    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+    MemRefDescriptor targetDesc(adaptor.target());
+    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getI1Type()),
+        rewriter.getBoolAttr(false));
+    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+                                    isVolatile);
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+  /// Lowers a copy to a call to the generic `MemrefCopyFn`.
+  LogicalResult
+  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto srcType = op.source().getType().cast<BaseMemRefType>();
     auto targetType = op.target().getType().cast<BaseMemRefType>();
@@ -765,6 +813,24 @@
 
     return success();
   }
+
+  /// Dispatches to the `memcpy` intrinsic lowering when both operands are
+  /// ranked memrefs with identity layouts, and to the generic copy function
+  /// call otherwise.
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.source().getType().cast<BaseMemRefType>();
+    auto targetType = op.target().getType().cast<BaseMemRefType>();
+
+    if (srcType.hasRank() &&
+        srcType.cast<MemRefType>().getLayout().isIdentity() &&
+        targetType.hasRank() &&
+        targetType.cast<MemRefType>().getLayout().isIdentity())
+      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
+
+    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+  }
 };
 
 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,7 +35,7 @@
   // CHECK-NEXT: [3, 4, 5]
 
   %copy_two = memref.alloc() : memref<3x2xf32>
-  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
     : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
@@ -49,6 +49,14 @@
   %copy_empty = memref.alloc() : memref<3x0x1xf32>
   // Copying an empty shape should do nothing (and should not crash).
   memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
+
+  %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
+    : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
+  %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
+  // Copying a casted empty shape should do nothing (and should not crash).
+  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
+
+  memref.dealloc %copy_empty_casted : memref<0x3x1xf32>
   memref.dealloc %copy_empty : memref<3x0x1xf32>
   memref.dealloc %input_empty : memref<3x0x1xf32>
   memref.dealloc %copy_two : memref<3x2xf32>