diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -320,10 +320,13 @@
 
 /// Returns true if the given `type` is a boolean scalar or vector type.
 static bool isBoolScalarOrVector(Type type) {
+  assert(type && "Not a valid type");
   if (type.isInteger(1))
     return true;
+
   if (auto vecType = type.dyn_cast<VectorType>())
     return vecType.getElementType().isInteger(1);
+
   return false;
 }
 
@@ -343,6 +346,22 @@
   return aBW != 0 && bBW != 0 && aBW == bBW;
 }
 
+/// Returns a source type conversion failure for `srcType` and operation `op`.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
+                         Type srcType) {
+  return rewriter.notifyMatchFailure(
+      op->getLoc(),
+      llvm::formatv("failed to convert source type '{0}'", srcType));
+}
+
+/// Returns a source type conversion failure for the result type of `op`.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
+  assert(op->getNumResults() == 1 && "expected a single-result op");
+  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type
 //===----------------------------------------------------------------------===//
@@ -562,10 +581,10 @@
     Op op, typename Op::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   assert(adaptor.getOperands().size() == 2);
-  auto dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = this->getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
                                                          adaptor.getOperands());
@@ -590,7 +609,8 @@
 
   Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
                                                    adaptor.getOperands());
 
@@ -611,7 +631,8 @@
 
   Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
                                                         adaptor.getOperands());
   return success();
@@ -628,7 +649,10 @@
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -649,7 +673,9 @@
     return failure();
 
   Location loc = op.getLoc();
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
 
   Value allOnes;
   if (auto intTy = dstType.dyn_cast<IntegerType>()) {
@@ -684,7 +710,10 @@
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -700,7 +729,10 @@
 LogicalResult
 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   if (!isBoolScalarOrVector(dstType))
     return failure();
 
@@ -728,10 +760,13 @@
     ConversionPatternRewriter &rewriter) const {
   assert(adaptor.getOperands().size() == 1);
   Type srcType = adaptor.getOperands().front().getType();
-  Type dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = this->getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
     return failure();
+
   if (dstType == srcType) {
     // Due to type conversion, we are seeing the same source and target type.
     // Then we can just erase this operation by forwarding its operand.
@@ -755,7 +790,7 @@
     return failure();
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op, srcType);
 
   switch (op.getPredicate()) {
   case arith::CmpIPredicate::eq: {
@@ -804,7 +839,7 @@
     return failure();
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op, srcType);
 
   switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
@@ -999,7 +1034,7 @@
   auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
   Type dstType = converter->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
 
   // arith.maxf/minf:
   //   "if one of the arguments is NaN, then the result is also NaN."
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -130,3 +130,19 @@
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+// i64 is not a valid result type in this target env (no Int64 capability).
+func.func @type_conversion_failure(%arg0: i32) {
+  // expected-error@+1 {{failed to legalize operation 'arith.extsi'}}
+  %2 = arith.extsi %arg0 : i32 to i64
+  return
+}
+
+} // end module