diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -857,12 +857,18 @@ rewriter.create(loc, numElements, sizeInBytes); Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); + Value srcOffset = srcDesc.offset(rewriter, loc); + Value srcPtr = rewriter.create(loc, srcBasePtr.getType(), + srcBasePtr, srcOffset); MemRefDescriptor targetDesc(adaptor.target()); Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); + Value targetOffset = targetDesc.offset(rewriter, loc); + Value targetPtr = rewriter.create(loc, targetBasePtr.getType(), + targetBasePtr, targetOffset); Value isVolatile = rewriter.create( loc, typeConverter->convertType(rewriter.getI1Type()), rewriter.getBoolAttr(false)); - rewriter.create(loc, targetBasePtr, srcBasePtr, totalSize, + rewriter.create(loc, targetPtr, srcPtr, totalSize, isVolatile); rewriter.eraseOp(op); @@ -933,10 +939,18 @@ auto srcType = op.source().getType().cast(); auto targetType = op.target().getType().cast(); - if (srcType.hasRank() && - srcType.cast().getLayout().isIdentity() && - targetType.hasRank() && - targetType.cast().getLayout().isIdentity()) + auto isContiguousMemrefType = [](BaseMemRefType type) { + auto memrefType = type.dyn_cast(); + // We can use memcpy for memrefs that have an identity layout or that are + // statically shaped and contiguous with an arbitrary offset. In the latter + // case empty memrefs are skipped; memrefCopy handles that special case. 
+ return memrefType && + (memrefType.getLayout().isIdentity() || + (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && + isStaticShapeAndContiguousRowMajor(memrefType))); + }; + + if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) return lowerToMemCopyIntrinsic(op, adaptor, rewriter); return lowerToMemCopyFunctionCall(op, adaptor, rewriter); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -933,10 +933,55 @@ // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr to i64 // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL]], [[PTRTOINT]] : i64 - // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[EXTRACT2:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1 - // CHECK: "llvm.intr.memcpy"([[EXTRACT2]], [[EXTRACT1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, 
!llvm.ptr, i64, i1) -> () + // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, !llvm.ptr, i64, i1) -> () + return +} + + +// ----- + +// CHECK-LABEL: func @memref_copy_contiguous +#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)> +func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) { + %buf = memref.alloc() : memref<1x2xi32> + %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, #map> + memref.copy %sub, %buf : memref<1x2xi32, #map> to memref<1x2xi32> + // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64 + // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MUL2:%.*]] = llvm.mul [[MUL1]], [[EXTRACT1]] : i64 + // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr + // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr to i64 + // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL2]], [[PTRTOINT]] : i64 + // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr, i64) -> !llvm.ptr + // 
CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1 + // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, !llvm.ptr, i64, i1) -> () + return +} + +// ----- + +// CHECK-LABEL: func @memref_copy_noncontiguous +#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)> +func @memref_copy_noncontiguous(%in: memref<16x2xi32>, %offset: index) { + %buf = memref.alloc() : memref<2x1xi32> + %sub = memref.subview %in[%offset, 0] [2, 1] [1, 1] : memref<16x2xi32> to memref<2x1xi32, #map> + memref.copy %sub, %buf : memref<2x1xi32, #map> to memref<2x1xi32> + // CHECK: llvm.call @memrefCopy return }