diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -705,7 +705,12 @@
   }
 };
 
-struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+/// Pattern to lower a `memref.copy` to the generic `MemrefCopyFn`.
+///
+/// This pattern only applies in the presence of non-identity maps on the
+/// operands.
+struct MemRefCopyOpWithMapsLowering
+    : public ConvertOpToLLVMPattern<memref::CopyOp> {
   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -715,6 +720,13 @@
     auto srcType = op.source().getType().cast<BaseMemRefType>();
     auto targetType = op.target().getType().cast<BaseMemRefType>();
 
+    if (srcType.hasRank() &&
+        srcType.cast<MemRefType>().getLayout().isIdentity() &&
+        targetType.hasRank() &&
+        targetType.cast<MemRefType>().getLayout().isIdentity())
+      return rewriter.notifyMatchFailure(
+          op, "identity layouts should be lowered to memcpy");
+
     // First make sure we have an unranked memref descriptor representation.
     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
       auto rank = rewriter.create<LLVM::ConstantOp>(
@@ -766,6 +778,56 @@
   }
 };
 
+/// Pattern to lower a `memref.copy` to the LLVM `memcpy` intrinsic.
+///
+/// This pattern is a fast path for copies that operate on memrefs with
+/// identity maps, i.e., where the data is contiguous in memory.
+struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto srcType = op.source().getType().dyn_cast<MemRefType>();
+    auto targetType = op.target().getType().dyn_cast<MemRefType>();
+
+    // Only ranked memrefs with identity layout are supported.
+    if (!srcType || !targetType || !srcType.getLayout().isIdentity() ||
+        !targetType.getLayout().isIdentity())
+      return rewriter.notifyMatchFailure(
+          op, "non-identity layout cannot use memcpy");
+
+    MemRefDescriptor srcDesc(adaptor.source());
+
+    // Compute the number of elements as the product of the sizes.
+    Value numElements;
+    for (int pos = 0; pos < srcType.getRank(); ++pos) {
+      auto size = srcDesc.size(rewriter, loc, pos);
+      numElements = numElements
+                        ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
+                        : size;
+    }
+    // Get the size of a single element in bytes.
+    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+    // Compute the total size in bytes.
+    Value totalSize =
+        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+
+    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+    MemRefDescriptor targetDesc(adaptor.target());
+    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getI1Type()),
+        rewriter.getBoolAttr(false));
+    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+                                    isVolatile);
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
 /// memref type. In unranked case, the fields are extracted from the underlying
 /// ranked descriptor.
@@ -1568,6 +1630,7 @@
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefCopyOpLowering,
+      MemRefCopyOpWithMapsLowering,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
       PrefetchOpLowering,
diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,9 +35,9 @@
 // CHECK-NEXT: [3, 4, 5]
 
   %copy_two = memref.alloc() : memref<3x2xf32>
-  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2]
-    : memref<3x2xf32> to memref<2x3xf32>
-  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32>
+  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
+    : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
+  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> ()
 // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]
@@ -49,6 +49,13 @@
   %copy_empty = memref.alloc() : memref<3x0x1xf32>
   // Copying an empty shape should do nothing (and should not crash).
   memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
+
+  %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [1, 3, 1]
+    : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [1, 3, 1]>
+  %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
+  // Copying a casted empty shape should do nothing (and should not crash).
+  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [1, 3, 1]> to memref<0x3x1xf32>
+
   memref.dealloc %copy_empty : memref<3x0x1xf32>
   memref.dealloc %input_empty : memref<3x0x1xf32>
   memref.dealloc %copy_two : memref<3x2xf32>
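
For reference, a rough sketch of the IR the new fast path produces (the SSA
names below are illustrative, not taken from actual lowering output): for an
identity-layout copy, the pattern multiplies the sizes read from the source
descriptor, scales by the element size, and emits a single memcpy on the
aligned pointers with a constant-false volatile flag:

  // memref.copy %src, %dst : memref<2x3xf32> to memref<2x3xf32>
  // lowers, roughly, to:
  %elems = llvm.mul %size0, %size1 : i64   // product of descriptor sizes
  %total = llvm.mul %elems, %eltsize : i64 // times sizeof(f32)
  %false = llvm.mlir.constant(false) : i1  // the copy is not volatile
  "llvm.intr.memcpy"(%dst_aligned, %src_aligned, %total, %false)
      : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()

The non-identity case still routes through `MemRefCopyOpWithMapsLowering`
and the runtime `MemrefCopyFn`, since strided data cannot be copied with a
single memcpy.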