diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1091,11 +1091,54 @@
                                   Type srcType, memref::ReshapeOp reshapeOp,
                                   memref::ReshapeOp::Adaptor adaptor,
                                   Value *descriptor) const {
-    // Conversion for statically-known shape args is performed via
-    // `memref_reinterpret_cast`.
     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
-    if (shapeMemRefType.hasStaticShape())
-      return failure();
+    if (shapeMemRefType.hasStaticShape()) {
+      // A statically sized shape operand yields a ranked result, so the
+      // target descriptor can be assembled directly from the result type.
+      MemRefType targetMemRefType =
+          reshapeOp.getResult().getType().cast<MemRefType>();
+      auto llvmTargetDescriptorTy =
+          typeConverter->convertType(targetMemRefType)
+              .dyn_cast_or_null<LLVM::LLVMStructType>();
+      if (!llvmTargetDescriptorTy)
+        return failure();
+
+      // Offset, sizes and strides are emitted as constants below, so only
+      // results with a static shape and a static strided layout are handled.
+      // Bailing out beats producing a descriptor with `undef` fields.
+      int64_t offset;
+      SmallVector<int64_t, 4> strides;
+      if (!targetMemRefType.hasStaticShape() ||
+          failed(getStridesAndOffset(targetMemRefType, strides, offset)) ||
+          ShapedType::isDynamicStrideOrOffset(offset) ||
+          llvm::any_of(strides, ShapedType::isDynamicStrideOrOffset))
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "expected result with static shape and strides");
+
+      // Create descriptor.
+      Location loc = reshapeOp.getLoc();
+      auto desc =
+          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+      // Reuse the allocated and aligned pointers of the source buffer.
+      Value allocatedPtr, alignedPtr;
+      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+                               reshapeOp.source(), adaptor.source(),
+                               &allocatedPtr, &alignedPtr);
+      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+      desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+      // Populate offset, sizes and strides; otherwise they stay `undef` and
+      // any access through the reshaped memref uses garbage layout values.
+      desc.setConstantOffset(rewriter, loc, offset);
+      for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+        desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
+        desc.setConstantStride(rewriter, loc, i, strides[i]);
+      }
+
+      *descriptor = desc;
+      return success();
+    }
 
     // The shape is a rank-1 tensor with unknown length.
     Location loc = reshapeOp.getLoc();
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -217,3 +217,37 @@
   }
 }
 
+// -----
+
+// Reshape with a statically sized shape operand: the source pointers are
+// reused and offset, sizes and strides are materialized as constants.
+memref.global "private" constant @__constant_3xi64 : memref<3xi64> = dense<[2, 6, 20]>
+// CHECK-LABEL: func @forward
+func @forward(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
+  %0 = memref.get_global @__constant_3xi64 : memref<3xi64>
+
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[elem0:.*]] = llvm.extractvalue %0[0]
+  // CHECK: %[[elem1:.*]] = llvm.extractvalue %0[1]
+  // CHECK: %[[ins0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0]
+  // CHECK: %[[ins1:.*]] = llvm.insertvalue %[[elem1]], %[[ins0]][1]
+  // CHECK: %[[offset:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[ins2:.*]] = llvm.insertvalue %[[offset]], %[[ins1]][2]
+  // CHECK: %[[size0:.*]] = llvm.mlir.constant(2 : index) : i64
+  // CHECK: %[[ins3:.*]] = llvm.insertvalue %[[size0]], %[[ins2]][3, 0]
+  // CHECK: %[[stride0:.*]] = llvm.mlir.constant(120 : index) : i64
+  // CHECK: %[[ins4:.*]] = llvm.insertvalue %[[stride0]], %[[ins3]][4, 0]
+  // CHECK: %[[size1:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[ins5:.*]] = llvm.insertvalue %[[size1]], %[[ins4]][3, 1]
+  // CHECK: %[[stride1:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[ins6:.*]] = llvm.insertvalue %[[stride1]], %[[ins5]][4, 1]
+  // CHECK: %[[size2:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[ins7:.*]] = llvm.insertvalue %[[size2]], %[[ins6]][3, 2]
+  // CHECK: %[[stride2:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[ins8:.*]] = llvm.insertvalue %[[stride2]], %[[ins7]][4, 2]
+  // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[ins8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
+  %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
+
+  // CHECK: return %[[cast]] : memref<2x6x20xf32>
+  return %1 : memref<2x6x20xf32>
+}