diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -723,13 +723,15 @@
     MemRefDescriptor srcDesc(adaptor.source());
 
     // Compute number of elements.
-    Value numElements;
+    // Start from a constant 1 so that rank-0 (scalar) memrefs, for which the
+    // loop below does not execute, still yield a single element.
+    Value numElements = rewriter.create<LLVM::ConstantOp>(
+        loc, getIndexType(), rewriter.getIndexAttr(1));
     for (int pos = 0; pos < srcType.getRank(); ++pos) {
       auto size = srcDesc.size(rewriter, loc, pos);
-      numElements = numElements
-                        ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
-                        : size;
+      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
     }
+
     // Get element size.
     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     // Compute total.
diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -8,6 +8,7 @@
 func @main() -> () {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42.0 : f32
 
   // Initialize input.
   %input = memref.alloc() : memref<2x3xf32>
@@ -56,10 +57,23 @@
   // Copying a casted empty shape should do nothing (and should not crash).
   memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
 
+  // Copying a rank-0 (scalar) memref should copy its single element.
+  %scalar = memref.alloc() : memref<f32>
+  memref.store %c42, %scalar[] : memref<f32>
+  %scalar_copy = memref.alloc() : memref<f32>
+  memref.copy %scalar, %scalar_copy : memref<f32> to memref<f32>
+  %unranked_scalar_copy = memref.cast %scalar_copy : memref<f32> to memref<*xf32>
+  call @print_memref_f32(%unranked_scalar_copy) : (memref<*xf32>) -> ()
+  // CHECK: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [42]
 
+  memref.dealloc %copy_empty : memref<3x0x1xf32>
+  memref.dealloc %copy_empty_casted : memref<0x3x1xf32>
   memref.dealloc %input_empty : memref<3x0x1xf32>
   memref.dealloc %copy_two : memref<3x2xf32>
   memref.dealloc %copy : memref<2x3xf32>
   memref.dealloc %input : memref<2x3xf32>
+  memref.dealloc %scalar : memref<f32>
+  memref.dealloc %scalar_copy : memref<f32>
   return
 }