diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -286,6 +287,11 @@
     SmallVector<Type, 8> returnTypes;
     for (auto result : ifOp.getResults()) {
       auto convertedType = typeConverter.convertType(result.getType());
+      // Bail out on unsupported result types rather than pushing a null type.
+      if (!convertedType)
+        return rewriter.notifyMatchFailure(
+            loc, llvm::formatv("failed to convert type '{0}'", result.getType()));
+
       returnTypes.push_back(convertedType);
     }
     replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir
--- a/mlir/test/Conversion/SCFToSPIRV/if.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir
@@ -153,4 +153,18 @@
   return
 }
 
+// Memrefs without a spirv storage class are not supported. The conversion
+// should preserve the `scf.if` and not crash.
+func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) {
+// CHECK-LABEL: @unsupported_yield_type
+// CHECK-NEXT: scf.if
+// CHECK: spirv.Return
+  %r = scf.if %c -> (memref<8xi32>) {
+    scf.yield %arg0 : memref<8xi32>
+  } else {
+    scf.yield %arg1 : memref<8xi32>
+  }
+  return
+}
+
 } // end module