[mlir] Add DialectCastOp::areCastCompatible and use it in StandardToLLVM

Move the operand/result type checks of the llvm.mlir.cast verifier into a
static predicate, DialectCastOp::areCastCompatible(Type, Type), declared via
extraClassDeclaration in LLVMOps.td. The two single-input materialization
callbacks in StandardToLLVM.cpp now return llvm::None when the predicate
rejects the (input, result) type pair, instead of creating a cast op that
would fail verification; this resolves the FIXMEs there. The verifier now
reports a single diagnostic, "types must be non-index integer types, float
types, or 1-d vector of mentioned types.", for every rejected cast, and the
expected-error annotations in invalid.mlir are updated to match.

NOTE(review): areCastCompatible checks each type on its own; it does not
check that the two types correspond to each other (for example, that one side
is an LLVM dialect type and the other is not).

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -820,6 +820,12 @@
       llvm.mlir.cast %v : llvm.float to f32
       llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32>
   }];
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result types for the
+    /// operation. Each type is checked on its own and must be a non-index
+    /// integer or float type, or a 1-d vector of such types.
+    static bool areCastCompatible(Type a, Type b);
+  }];
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)";
   let verifier = "return ::verify(*this);";
 }
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -170,8 +170,10 @@
                                Location loc) -> Optional<Value> {
     if (inputs.size() != 1)
       return llvm::None;
-    // FIXME: These should check LLVM::DialectCastOp can actually be constructed
-    // from the input and result.
+    // Do not materialize a cast that the llvm.mlir.cast verifier would reject.
+    if (!LLVM::DialectCastOp::areCastCompatible(inputs[0].getType(),
+                                                resultType))
+      return llvm::None;
     return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
         .getResult();
   });
@@ -180,8 +182,10 @@
                                Location loc) -> Optional<Value> {
     if (inputs.size() != 1)
       return llvm::None;
-    // FIXME: These should check LLVM::DialectCastOp can actually be constructed
-    // from the input and result.
+    // Do not materialize a cast that the llvm.mlir.cast verifier would reject.
+    if (!LLVM::DialectCastOp::areCastCompatible(inputs[0].getType(),
+                                                resultType))
+      return llvm::None;
     return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
         .getResult();
   });
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -981,26 +981,26 @@
 // Verifier for LLVM::DialectCastOp.
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(DialectCastOp op) {
-  auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
+bool mlir::LLVM::DialectCastOp::areCastCompatible(Type a, Type b) {
+  // Returns true if `type` is allowed as the operand or the result type.
+  auto verifyMLIRCastType = [](Type type) -> bool {
     if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
       if (llvmType.isVectorTy())
         llvmType = llvmType.getVectorElementType();
       if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
           llvmType.isHalfTy() || llvmType.isFloatTy() ||
           llvmType.isDoubleTy()) {
-        return success();
+        return true;
       }
-      return op.emitOpError("type must be non-index integer types, float "
-                            "types, or vector of mentioned types.");
+      return false;
     }
     if (auto vectorType = type.dyn_cast<VectorType>()) {
       if (vectorType.getShape().size() > 1)
-        return op.emitOpError("only 1-d vector is allowed");
+        return false;
       type = vectorType.getElementType();
     }
     if (type.isSignlessIntOrFloat())
-      return success();
+      return true;
     // Note that memrefs are not supported. We currently don't have a use case
     // for it, but even if we do, there are challenges:
     // * if we allow memrefs to cast from/to memref descriptors, then the
@@ -1010,11 +1010,16 @@
     // alternatively want metadata that only present in the descriptor.
     //
     // TODO: re-evaluate the memref cast design when it's needed.
-    return op.emitOpError("type must be non-index integer types, float types, "
-                          "or vector of mentioned types.");
+    return false;
   };
-  return failure(failed(verifyMLIRCastType(op.in().getType())) ||
-                 failed(verifyMLIRCastType(op.getType())));
+  return verifyMLIRCastType(a) && verifyMLIRCastType(b);
+}
+
+static LogicalResult verify(DialectCastOp op) {
+  if (!DialectCastOp::areCastCompatible(op.in().getType(), op.getType()))
+    return op.emitOpError("types must be non-index integer types, float types, "
+                          "or 1-d vector of mentioned types.");
+  return success();
 }
 
 // Parses one of the keywords provided in the list `keywords` and returns the
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
 
 func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+1 {{'llvm.mlir.cast' op types must be non-index integer types, float types, or 1-d vector of mentioned types}}
   %1 = llvm.mlir.cast %0 : index to !llvm.i64
   return %1 : !llvm.i64
 }
@@ -9,7 +9,7 @@
 // -----
 
 func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+1 {{'llvm.mlir.cast' op types must be non-index integer types, float types, or 1-d vector of mentioned types}}
   %1 = llvm.mlir.cast %0 : !llvm.i64 to index
   return %1 : index
 }
@@ -25,7 +25,7 @@
 // -----
 
 func @mlir_cast_to_llvm_vec(%0 : vector<1x1xf32>) -> !llvm.vec<1 x float> {
-  // expected-error@+1 {{'llvm.mlir.cast' op only 1-d vector is allowed}}
+  // expected-error@+1 {{'llvm.mlir.cast' op types must be non-index integer types, float types, or 1-d vector of mentioned types}}
   %1 = llvm.mlir.cast %0 : vector<1x1xf32> to !llvm.vec<1 x float>
   return %1 : !llvm.vec<1 x float>
 }