diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1067,40 +1067,193 @@ // Verifier for LLVM::DialectCastOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(DialectCastOp op) { - auto verifyMLIRCastType = [&op](Type type) -> LogicalResult { - if (auto llvmType = type.dyn_cast()) { - if (llvmType.isVectorTy()) - llvmType = llvmType.getVectorElementType(); - if (llvmType.isIntegerTy() || llvmType.isBFloatTy() || - llvmType.isHalfTy() || llvmType.isFloatTy() || - llvmType.isDoubleTy()) { - return success(); - } - return op.emitOpError("type must be non-index integer types, float " - "types, or vector of mentioned types."); +/// Checks if `llvmType` is dialect cast-compatible with `index` type. Does not +/// report the error; the caller is expected to produce an appropriate message. +// TODO: make the size depend on data layout rather than on the conversion +// pass option, and pull that information here. +static LogicalResult verifyCastWithIndex(LLVMType llvmType) { + return success(llvmType.isa()); +} + +/// Checks if `llvmType` is dialect cast-compatible with built-in `type` and +/// reports errors to the location of `op`. +static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType, + Type type) { + // Index is compatible with any integer. + if (type.isIndex()) { + if (succeeded(verifyCastWithIndex(llvmType))) + return success(); + + return op.emitOpError("invalid cast between index and non-integer type"); + } + + // Simple one-to-one mappings for floating point types.
+ if (type.isF16()) { + if (llvmType.isa()) + return success(); + return op.emitOpError( + "invalid cast between f16 and a type other than !llvm.half"); + } + if (type.isBF16()) { + if (llvmType.isa()) + return success(); + return op->emitOpError( + "invalid cast between bf16 and a type other than !llvm.bfloat"); + } + if (type.isF32()) { + if (llvmType.isa()) + return success(); + return op->emitOpError( + "invalid cast between f32 and a type other than !llvm.float"); + } + if (type.isF64()) { + if (llvmType.isa()) + return success(); + return op->emitOpError( + "invalid cast between f64 and a type other than !llvm.double"); + } + + // Signless integers are compatible with LLVM integers of the same bitwidth. + if (type.isSignlessInteger()) { + auto llvmInt = llvmType.dyn_cast(); + if (!llvmInt) + return op->emitOpError( + "invalid cast between integer and non-integer type"); + if (llvmInt.getBitWidth() == type.getIntOrFloatBitWidth()) + return success(); + + return op->emitOpError( + "invalid cast between integers with mismatching bitwidth"); + } + + // Vectors are compatible if they are 1D non-scalable, and their element types + // are compatible. + if (auto vectorType = type.dyn_cast()) { + if (vectorType.getRank() != 1) + return op->emitOpError("only 1-d vector is allowed"); + + auto llvmVector = llvmType.dyn_cast(); + // A non-vector LLVM type makes `llvmVector` null; reject it before use. + if (!llvmVector) + return op->emitOpError("invalid cast between vector and non-vector type"); + if (llvmVector.isa()) + return op->emitOpError("only fixed-sized vector is allowed"); + + if (vectorType.getDimSize(0) != llvmVector.getVectorNumElements()) + return op->emitOpError( + "invalid cast between vectors with mismatching sizes"); + + return verifyCast(op, llvmVector.getElementType(), + vectorType.getElementType()); + } + + if (auto memrefType = type.dyn_cast()) { + // Bare pointer convention: statically-shaped memref is compatible with an + // LLVM pointer to the element type.
+ if (auto ptrType = llvmType.dyn_cast()) { + if (!memrefType.hasStaticShape()) + return op->emitOpError( + "unexpected bare pointer for dynamically shaped memref"); + if (memrefType.getMemorySpace() != ptrType.getAddressSpace()) + return op->emitOpError("invalid conversion between memref and pointer " + "in different memory spaces"); + + return verifyCast(op, ptrType.getElementType(), + memrefType.getElementType()); } - if (auto vectorType = type.dyn_cast()) { - if (vectorType.getShape().size() > 1) - return op.emitOpError("only 1-d vector is allowed"); - type = vectorType.getElementType(); + + // Otherwise, memrefs are convertible to a descriptor, which is a structure + // type. + auto structType = llvmType.dyn_cast(); + if (!structType) + return op->emitOpError("invalid cast between a memref and a type other " + "than pointer or memref descriptor"); + + unsigned expectedNumElements = memrefType.getRank() == 0 ? 3 : 5; + if (structType.getBody().size() != expectedNumElements) { + return op->emitOpError() << "expected memref descriptor with " + << expectedNumElements << " elements"; } - if (type.isSignlessIntOrFloat()) + + // The first two elements are pointers to the element type.
+ auto allocatedPtr = structType.getBody()[0].dyn_cast(); + if (!allocatedPtr || + allocatedPtr.getAddressSpace() != memrefType.getMemorySpace()) + return op->emitOpError("expected first element of a memref descriptor to " + "be a pointer in the address space of the memref"); + if (failed(verifyCast(op, allocatedPtr.getElementType(), + memrefType.getElementType()))) + return failure(); + + auto alignedPtr = structType.getBody()[1].dyn_cast(); + if (!alignedPtr || + alignedPtr.getAddressSpace() != memrefType.getMemorySpace()) + return op->emitOpError( + "expected second element of a memref descriptor to " + "be a pointer in the address space of the memref"); + if (failed(verifyCast(op, alignedPtr.getElementType(), + memrefType.getElementType()))) + return failure(); + + // The third element (offset) is an equivalent of index. + if (failed(verifyCastWithIndex(structType.getBody()[2]))) + return op->emitOpError("expected third element of a memref descriptor to " + "be index-compatible integers"); + + // 0D memrefs don't have sizes/strides. + if (memrefType.getRank() == 0) return success(); - // Note that memrefs are not supported. We currently don't have a use case - // for it, but even if we do, there are challenges: - // * if we allow memrefs to cast from/to memref descriptors, then the - // semantics of the cast op depends on the implementation detail of the - // descriptor. - // * if we allow memrefs to cast from/to bare pointers, some users might - // alternatively want metadata that only present in the descriptor. - // - // TODO: re-evaluate the memref cast design when it's needed. - return op.emitOpError("type must be non-index integer types, float types, " - "or vector of mentioned types."); - }; - return failure(failed(verifyMLIRCastType(op.in().getType())) || - failed(verifyMLIRCastType(op.getType()))); + + // Sizes and strides are rank-sized arrays of `index` equivalents.
+ auto sizes = structType.getBody()[3].dyn_cast(); + if (!sizes || failed(verifyCastWithIndex(sizes.getElementType())) || + sizes.getNumElements() != memrefType.getRank()) + return op->emitOpError( + "expected fourth element of a memref descriptor " + "to be an array of index-compatible integers"); + + auto strides = structType.getBody()[4].dyn_cast(); + if (!strides || failed(verifyCastWithIndex(strides.getElementType())) || + strides.getNumElements() != memrefType.getRank()) + return op->emitOpError( + "expected fifth element of a memref descriptor " + "to be an array of index-compatible integers"); + + return success(); + } + + // Unranked memrefs are compatible with their descriptors. + if (auto unrankedMemrefType = type.dyn_cast()) { + auto structType = llvmType.dyn_cast(); + if (!structType || structType.getBody().size() != 2) + return op->emitOpError( + "expected descriptor to be a struct with two elements"); + + if (failed(verifyCastWithIndex(structType.getBody()[0]))) + return op->emitOpError("expected first element of a memref descriptor to " + "be an index-compatible integer"); + + auto ptrType = structType.getBody()[1].dyn_cast(); + if (!ptrType || !ptrType.getPointerElementTy().isIntegerTy(8)) + return op->emitOpError("expected second element of a memref descriptor " + "to be an !llvm.ptr"); + + return success(); + } + + // Everything else is not supported.
+ return op->emitOpError("unsupported cast"); +} + +static LogicalResult verify(DialectCastOp op) { + if (auto llvmType = op.getType().dyn_cast()) + return verifyCast(op, llvmType, op.in().getType()); + + auto llvmType = op.in().getType().dyn_cast(); + if (!llvmType) + return op->emitOpError("expected one LLVM type and one built-in type"); + + return verifyCast(op, llvmType, op.getType()); } // Parses one of the keywords provided in the list `keywords` and returns the diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir --- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir +++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir @@ -1,37 +1,5 @@ // RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file -func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 { - // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}} - %1 = llvm.mlir.cast %0 : index to !llvm.i64 - return %1 : !llvm.i64 -} - -// ----- - -func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index { - // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}} - %1 = llvm.mlir.cast %0 : !llvm.i64 to index - return %1 : index -} - -// ----- - -func @mlir_cast_to_llvm_int(%0 : i32) -> !llvm.i64 { - // expected-error@+1 {{failed to legalize operation 'llvm.mlir.cast' that was explicitly marked illegal}} - %1 = llvm.mlir.cast %0 : i32 to !llvm.i64 - return %1 : !llvm.i64 -} - -// ----- - -func @mlir_cast_to_llvm_vec(%0 : vector<1x1xf32>) -> !llvm.vec<1 x float> { - // expected-error@+1 {{'llvm.mlir.cast' op only 1-d vector is allowed}} - %1 = llvm.mlir.cast %0 : vector<1x1xf32> to !llvm.vec<1 x float> - return %1 : !llvm.vec<1 x float> -} - -// ----- - // Should not crash on unsupported types in function signatures.
func private @unsupported_signature() -> tensor<10 x i32> diff --git a/mlir/test/Dialect/LLVMIR/dialect-cast.mlir b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir @@ -0,0 +1,258 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +// These are the supported cases; just make sure they don't trigger errors. Op +// syntax is tested elsewhere. + +func @mlir_dialect_cast(%0: index, %1: i32, %2: bf16, %3: f16, %4: f32, %5: f64, + %6: vector<42xf32>, %7: memref<42xf32>, + %8: memref, %9: memref, + %10: memref<*xf32>) { + llvm.mlir.cast %0 : index to !llvm.i64 + llvm.mlir.cast %0 : index to !llvm.i32 + llvm.mlir.cast %1 : i32 to !llvm.i32 + llvm.mlir.cast %2 : bf16 to !llvm.bfloat + llvm.mlir.cast %3 : f16 to !llvm.half + llvm.mlir.cast %4 : f32 to !llvm.float + llvm.mlir.cast %5 : f64 to !llvm.double + llvm.mlir.cast %6 : vector<42xf32> to !llvm.vec<42xfloat> + llvm.mlir.cast %7 : memref<42xf32> to !llvm.ptr + llvm.mlir.cast %7 : memref<42xf32> to !llvm.struct<(ptr, ptr, i64, array<1xi64>, array<1xi64>)> + llvm.mlir.cast %8 : memref to !llvm.struct<(ptr, ptr, i64, array<1xi64>, array<1xi64>)> + llvm.mlir.cast %9 : memref to !llvm.struct<(ptr, ptr, i64)> + llvm.mlir.cast %10 : memref<*xf32> to !llvm.struct<(i64, ptr)> + return +} + +// ----- + +func @mlir_dialect_cast_index_non_integer(%0 : index) { + // expected-error@+1 {{invalid cast between index and non-integer type}} + %1 = llvm.mlir.cast %0 : index to !llvm.float +} + +// ----- + +// The cast verifier is symmetric, so symmetry is checked only once, by a pair +// of std->llvm and llvm->std tests. Everything else is std->llvm.
+ +func @mlir_dialect_cast_index_non_integer_symmetry(%0: !llvm.float) { + // expected-error@+1 {{invalid cast between index and non-integer type}} + llvm.mlir.cast %0 : !llvm.float to index +} + +// ----- + +func @mlir_dialect_cast_f16(%0 : f16) { + // expected-error@+1 {{invalid cast between f16 and a type other than !llvm.half}} + llvm.mlir.cast %0 : f16 to !llvm.float +} + +// ----- + +func @mlir_dialect_cast_bf16(%0 : bf16) { + // expected-error@+1 {{invalid cast between bf16 and a type other than !llvm.bfloat}} + llvm.mlir.cast %0 : bf16 to !llvm.half +} + +// ----- + +func @mlir_dialect_cast_f32(%0 : f32) { + // expected-error@+1 {{invalid cast between f32 and a type other than !llvm.float}} + llvm.mlir.cast %0 : f32 to !llvm.bfloat +} + +// ----- + +func @mlir_dialect_cast_f64(%0 : f64) { + // expected-error@+1 {{invalid cast between f64 and a type other than !llvm.double}} + llvm.mlir.cast %0 : f64 to !llvm.float +} + +// ----- + +func @mlir_dialect_cast_integer_non_integer(%0 : i16) { + // expected-error@+1 {{invalid cast between integer and non-integer type}} + llvm.mlir.cast %0 : i16 to !llvm.half +} + +// ----- + +func @mlir_dialect_cast_integer_bitwidth_mismatch(%0 : i16) { + // expected-error@+1 {{invalid cast between integers with mismatching bitwidth}} + llvm.mlir.cast %0 : i16 to !llvm.i32 +} + +// ----- + +func @mlir_dialect_cast_nd_vector(%0 : vector<2x2xf32>) { + // expected-error@+1 {{only 1-d vector is allowed}} + llvm.mlir.cast %0 : vector<2x2xf32> to !llvm.vec<4xfloat> +} + +// ----- + +func @mlir_dialect_cast_scalable_vector(%0 : vector<2xf32>) { + // expected-error@+1 {{only fixed-sized vector is allowed}} + llvm.mlir.cast %0 : vector<2xf32> to !llvm.vec +} + +// ----- + +func @mlir_dialect_cast_vector_size_mismatch(%0 : vector<2xf32>) { + // expected-error@+1 {{invalid cast between vectors with mismatching sizes}} + llvm.mlir.cast %0 : vector<2xf32> to !llvm.vec<4xfloat> +} + +// ----- + +func
@mlir_dialect_cast_dynamic_memref_bare_ptr(%0 : memref) { + // expected-error@+1 {{unexpected bare pointer for dynamically shaped memref}} + llvm.mlir.cast %0 : memref to !llvm.ptr +} + +// ----- + +func @mlir_dialect_cast_memref_bare_ptr_space(%0 : memref<4xf32, 4>) { + // expected-error@+1 {{invalid conversion between memref and pointer in different memory spaces}} + llvm.mlir.cast %0 : memref<4xf32, 4> to !llvm.ptr +} + +// ----- + +func @mlir_dialect_cast_memref_no_descriptor(%0 : memref) { + // expected-error@+1 {{invalid cast between a memref and a type other than pointer or memref descriptor}} + llvm.mlir.cast %0 : memref to !llvm.float +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_wrong_num_elements(%0 : memref) { + // expected-error@+1 {{expected memref descriptor with 5 elements}} + llvm.mlir.cast %0 : memref to !llvm.struct<()> +} + +// ----- + +func @mlir_dialect_cast_0d_memref_descriptor_wrong_num_elements(%0 : memref) { + // expected-error@+1 {{expected memref descriptor with 3 elements}} + llvm.mlir.cast %0 : memref to !llvm.struct<()> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_allocated(%0 : memref) { + // expected-error@+1 {{expected first element of a memref descriptor to be a pointer in the address space of the memref}} + llvm.mlir.cast %0 : memref to !llvm.struct<(float, float, float, float, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_allocated_wrong_space(%0 : memref) { + // expected-error@+1 {{expected first element of a memref descriptor to be a pointer in the address space of the memref}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, float, float, float, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_aligned(%0 : memref) { + // expected-error@+1 {{expected second element of a memref descriptor to be a pointer in the address space of the memref}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, float, float, float, float)> +} + +// ----- + +func
@mlir_dialect_cast_memref_descriptor_aligned_wrong_space(%0 : memref) { + // expected-error@+1 {{expected second element of a memref descriptor to be a pointer in the address space of the memref}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, float, float, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_offset(%0 : memref) { + // expected-error@+1 {{expected third element of a memref descriptor to be index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, float, float, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_sizes(%0 : memref) { + // expected-error@+1 {{expected fourth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, float, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_sizes_wrong_type(%0 : memref) { + // expected-error@+1 {{expected fourth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, array<10xfloat>, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_sizes_wrong_rank(%0 : memref) { + // expected-error@+1 {{expected fourth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, array<10xi64>, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_strides(%0 : memref) { + // expected-error@+1 {{expected fifth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, array<1xi64>, float)> +} + +// ----- + +func @mlir_dialect_cast_memref_descriptor_strides_wrong_type(%0 : memref) { + // expected-error@+1 {{expected fifth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, array<1xi64>, array<10xfloat>)> +}
+// ----- + +func @mlir_dialect_cast_memref_descriptor_strides_wrong_rank(%0 : memref) { + // expected-error@+1 {{expected fifth element of a memref descriptor to be an array of index-compatible integers}} + llvm.mlir.cast %0 : memref to !llvm.struct<(ptr, ptr, i64, array<1xi64>, array<10xi64>)> +} + +// ----- + +func @mlir_dialect_cast_tensor(%0 : tensor) { + // expected-error@+1 {{unsupported cast}} + llvm.mlir.cast %0 : tensor to !llvm.float +} + +// ----- + +func @mlir_dialect_cast_two_std_types(%0 : f32) { + // expected-error@+1 {{expected one LLVM type and one built-in type}} + llvm.mlir.cast %0 : f32 to f64 +} + +// ----- + +func @mlir_dialect_cast_unranked_memref(%0: memref<*xf32>) { + // expected-error@+1 {{expected descriptor to be a struct with two elements}} + llvm.mlir.cast %0 : memref<*xf32> to !llvm.ptr +} + +// ----- + +func @mlir_dialect_cast_unranked_memref_wrong_num_elements(%0: memref<*xf32>) { + // expected-error@+1 {{expected descriptor to be a struct with two elements}} + llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<()> +} + +// ----- + +func @mlir_dialect_cast_unranked_rank(%0: memref<*xf32>) { + // expected-error@+1 {{expected first element of a memref descriptor to be an index-compatible integer}} + llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<(float, float)> +} + +// ----- + +func @mlir_dialect_cast_unranked_ptr(%0: memref<*xf32>) { + // expected-error@+1 {{expected second element of a memref descriptor to be an !llvm.ptr}} + llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<(i64, float)> +}