diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -83,11 +83,12 @@ /// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { - if (srcType.isa()) + if (srcType.isa()) { return rewriter.create( loc, dstType, SplatElementsAttr::get(srcType.cast(), minusOneIntegerAttribute(srcType, rewriter))); + } return rewriter.create( loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } @@ -239,7 +240,7 @@ matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -328,7 +329,7 @@ matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -381,7 +382,7 @@ matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -473,7 +474,7 @@ } // Function returns a single result. - auto dstType = this->typeConverter.convertType(callOp.getType(0)); + auto dstType = typeConverter.convertType(callOp.getType(0)); rewriter.replaceOpWithNewOp(callOp, dstType, operands, callOp.getAttrs()); return success(); @@ -638,7 +639,7 @@ auto funcType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); - auto llvmType = this->typeConverter.convertFunctionSignature( + auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), /*isVariadic=*/false, signatureConverter); if (!llvmType) return failure(); @@ -675,7 +676,10 @@ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) { + return failure(); + } rewriter.eraseOp(funcOp); return success(); }