diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -147,6 +147,7 @@ /// Performs the index computation to get to the element at `indices` of the /// memory pointed to by `basePtr`, using the layout map of `baseType`. +/// Returns null if index computation cannot be performed. // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap // that has static strides. Extend to handle dynamic strides. diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -268,6 +268,9 @@ spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), loadOperands.indices(), loc, rewriter); + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; if (isBool) @@ -358,6 +361,10 @@ auto loadPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + + if (!loadPtr) + return failure(); + rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } @@ -376,6 +383,10 @@ spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); + + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; @@ -467,6 +478,10 @@ spirv::getElementPtr(*getTypeConverter(), memrefType, storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(), rewriter); + + if (!storePtr) + return failure(); + rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return success();