diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -223,6 +223,14 @@
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t);
 
+  /// Attempts a 1-1 type conversion, expecting the result type to be
+  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
+  /// and a null type on conversion or cast failure.
+  template <typename TargetType>
+  TargetType convertType(Type t) {
+    return dyn_cast_or_null<TargetType>(convertType(t));
+  }
+
   /// Convert the given set of types, filling 'results' as necessary. This
   /// returns failure if the conversion of any of the types fails, success
   /// otherwise.
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
@@ -410,9 +411,12 @@
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  Type pointeeType = typeConverter.convertType(memrefType)
-                         .cast<spirv::PointerType>()
-                         .getPointeeType();
+
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
@@ -541,9 +545,12 @@
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
 
-  Type pointeeType = typeConverter.convertType(memrefType)
-                         .cast<spirv::PointerType>()
-                         .getPointeeType();
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -206,7 +207,11 @@
   matchAndRewrite(arith::ConstantOp op, OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type oldType = op.getType();
-    auto newType = getTypeConverter()->convertType(oldType).cast<VectorType>();
+    auto newType = getTypeConverter()->convertType<VectorType>(oldType);
+    if (!newType)
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("unsupported type: {0}", op.getType()));
+
     unsigned newBitWidth = newType.getElementTypeBitWidth();
 
     Attribute oldValue = op.getValueAttr();
@@ -264,9 +269,7 @@
   matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -307,9 +310,8 @@
   matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = this->getTypeConverter()
-                     ->convertType(op.getType())
-                     .template dyn_cast_or_null<VectorType>();
+    auto newTy = this->getTypeConverter()->template convertType<VectorType>(
+        op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -357,9 +359,8 @@
   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto inputTy = getTypeConverter()
-                       ->convertType(op.getLhs().getType())
-                       .dyn_cast_or_null<VectorType>();
+    auto inputTy =
+        getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
     if (!inputTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -414,9 +415,7 @@
   matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -457,9 +456,7 @@
   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -497,9 +494,7 @@
   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -577,9 +572,8 @@
     Location loc = op.getLoc();
 
     Type inType = op.getIn().getType();
-    auto newInTy = this->getTypeConverter()
-                       ->convertType(inType)
-                       .template dyn_cast_or_null<VectorType>();
+    auto newInTy =
+        this->getTypeConverter()->template convertType<VectorType>(inType);
     if (!newInTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", inType));
@@ -608,8 +602,7 @@
         this->template getTypeConverter<arith::WideIntEmulationConverter>();
 
     Type resultType = op.getType();
-    auto newTy = typeConverter->convertType(resultType)
-                     .template dyn_cast_or_null<VectorType>();
+    auto newTy = typeConverter->template convertType<VectorType>(resultType);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", resultType));
@@ -640,9 +633,7 @@
   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -677,8 +668,7 @@
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -767,8 +757,7 @@
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -857,8 +846,7 @@
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -922,8 +910,7 @@
 
     Value in = op.getIn();
     Type oldTy = in.getType();
-    auto newTy =
-        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));
@@ -967,8 +954,7 @@
     Location loc = op.getLoc();
 
     Type oldTy = op.getIn().getType();
-    auto newTy =
-        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));