diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -304,6 +304,9 @@ // Get the SPIR-V type for the allocation. Type spirvType = getTypeConverter()->convertType(allocType); + if (!spirvType) + return rewriter.notifyMatchFailure(allocaOp, "type conversion failed"); + rewriter.replaceOpWithNewOp(allocaOp, spirvType, spirv::StorageClass::Function, /*initializer=*/nullptr); @@ -323,6 +326,8 @@ // Get the SPIR-V type for the allocation. Type spirvType = getTypeConverter()->convertType(allocType); + if (!spirvType) + return rewriter.notifyMatchFailure(operation, "type conversion failed"); // Insert spirv.GlobalVariable for this allocation. Operation *parent = @@ -467,7 +472,7 @@ int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); - // If the rewrited load op has the same bit width, use the loading value + // If the rewritten load op has the same bit width, use the loaded value // directly. if (srcBits == dstBits) { Value loadVal = rewriter.create(loc, accessChain); @@ -701,12 +706,16 @@ Value result = adaptor.getSource(); Type resultPtrType = typeConverter.convertType(resultType); + if (!resultPtrType) + return rewriter.notifyMatchFailure(addrCastOp, + "failed to convert memref type"); + Type genericPtrType = resultPtrType; // SPIR-V doesn't have a general address space cast operation. Instead, it has // conversions to and from generic pointers. To implement the general case, // we use specific-to-generic conversions when the source class is not // generic. Then when the result storage class is not generic, we convert the - // generic pointer (either the input on ar intermediate result) to theat + // generic pointer (either the input or an intermediate result) to that // class.
This also means that we'll need the intermediate generic pointer // type if neither the source or destination have it. if (sourceSc != spirv::StorageClass::Generic && diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -563,6 +563,12 @@ return nullptr; } + if (*memrefSize == 0) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: zero-element memrefs are not supported\n"); + return nullptr; + } + int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir @@ -51,7 +51,6 @@ // CHECK: %{{.+}} = spirv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr // CHECK: %{{.+}} = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr - // ----- module attributes { @@ -92,7 +91,6 @@ // CHECK-SAME: !spirv.ptr>)>, Workgroup> // CHECK-LABEL: func @two_allocs_vector() - // ----- module attributes { @@ -179,3 +177,19 @@ // CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]] // CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32 // CHECK-NOT: memref.dealloc + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> + } +{ + func.func @zero_size() { + %0 = memref.alloc() : memref<0xf32, #spirv.storage_class> + return + } +} + +// Zero-sized allocations are not handled yet. Just make sure we do not crash. 
+// CHECK-LABEL: func @zero_size() diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir @@ -16,7 +16,6 @@ // CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[VAR]] // CHECK: spirv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32 - // ----- module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { @@ -69,3 +68,15 @@ return %1: f32 } } + +// ----- + +module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + func.func @zero_size() { + %0 = memref.alloca() : memref<0xf32, #spirv.storage_class> + return + } +} + +// Zero-sized allocations are not handled yet. Just make sure we do not crash. +// CHECK-LABEL: func @zero_size diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -608,4 +608,14 @@ return %ret : memref<4x4xf32, #spirv.storage_class> } +// TODO: Casting to a zero-element memref is not supported yet; the memref.cast is left unconverted. +// CHECK-LABEL: func.func @cast_to_static_zero_elems +// CHECK-SAME: (%[[MEM:.*]]: memref>) +func.func @cast_to_static_zero_elems(%arg: memref>) -> memref<0xf32, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = memref.cast %[[MEM]] : memref> to memref<0xf32, #spirv.storage_class> +// CHECK: return %[[MEM1]] + %ret = memref.cast %arg : memref> to memref<0xf32, #spirv.storage_class> + return %ret : memref<0xf32, #spirv.storage_class> +} + }