[mlir][MemRefToLLVM] Pass an element count, not a byte count, to llvm.alloca

llvm.alloca's size operand is a number of elements of the allocated type
(the tests below print it as `llvm.alloca %size x f32`). The lowering used
to forward the buffer size in bytes; it now divides that byte size by the
size of one element before creating the alloca.

Test updates:
  * The udiv CHECK lines reuse the capture of the preceding ptrtoint, so the
    dividend is actually verified to be the computed byte size.
  * The wmma test no longer hard-codes an SSA value number.

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -103,9 +103,14 @@
         *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
     auto elementPtrType =
         getTypeConverter()->getPointerType(elementType, addrSpace);
 
+    // llvm.alloca takes a number of elements of `elementType`, not a byte
+    // count, so convert `sizeBytes` back into an element count.
+    auto sizeOfElement = getSizeInBytes(loc, elementType, rewriter);
+    auto numElements =
+        rewriter.create<LLVM::UDivOp>(loc, sizeBytes, sizeOfElement);
     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, elementPtrType, elementType, sizeBytes,
+        loc, elementPtrType, elementType, numElements,
         allocaOp.getAlignment().value_or(0));
 
     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -96,7 +96,7 @@
   // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
   // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
   // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
@@ -112,7 +112,7 @@
   // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
+  // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
   // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
   // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
   // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -86,10 +86,14 @@
 // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
 // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64
-// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x f32 : (i64) -> !llvm.ptr
+// CHECK-NEXT: %[[null1:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[gep1:.*]] = llvm.getelementptr %[[null1]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep1]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[null2:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[gep2:.*]] = llvm.getelementptr %[[null2]][1] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[size_elem:.*]] = llvm.ptrtoint %[[gep2]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[size:.*]] = llvm.udiv %[[sz_bytes]], %[[size_elem]] : i64
+// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[size]] x f32 : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -79,10 +79,14 @@
 // CHECK: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64
 // CHECK: %[[st2:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64
-// CHECK: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-// CHECK: %[[allocated:.*]] = llvm.alloca %[[size_bytes]] x f32 : (i64) -> !llvm.ptr
+// CHECK: %[[null1:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK: %[[gep1:.*]] = llvm.getelementptr %[[null1]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep1]] : !llvm.ptr to i64
+// CHECK: %[[null2:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK: %[[gep2:.*]] = llvm.getelementptr %[[null2]][1] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK: %[[size_elem:.*]] = llvm.ptrtoint %[[gep2]] : !llvm.ptr to i64
+// CHECK: %[[size:.*]] = llvm.udiv %[[size_bytes]], %[[size_elem]] : i64
+// CHECK: %[[allocated:.*]] = llvm.alloca %[[size]] x f32 : (i64) -> !llvm.ptr
   %0 = memref.alloca() : memref<32x18xf32>
 
  // Test with explicitly specified alignment. llvm.alloca takes care of the