diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -1033,6 +1033,9 @@ spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), loadOperands.indices(), loc, rewriter); + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; if (isBool) @@ -1123,6 +1126,10 @@ auto loadPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + + if (!loadPtr) + return failure(); + rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } @@ -1194,6 +1201,10 @@ spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); + + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; @@ -1285,6 +1296,10 @@ spirv::getElementPtr(*getTypeConverter(), memrefType, storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(), rewriter); + + if (!storePtr) + return failure(); + rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return success();