Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -988,39 +988,130 @@
 //===----------------------------------------------------------------------===//
 
+/// Verifies that `llvm.mlir.cast` converts between a non-LLVM type and its
+/// LLVM dialect counterpart: a signless integer or float type and the LLVM
+/// scalar of the same bitwidth and format, 1-D vectors of such types with the
+/// same number of elements, or a memref and its descriptor struct. Casting a
+/// type to itself is accepted as a no-op.
 static LogicalResult verify(DialectCastOp op) {
-  auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
-    if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
-      if (llvmType.isVectorTy())
-        llvmType = llvmType.getVectorElementType();
-      if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
-          llvmType.isHalfTy() || llvmType.isFloatTy() ||
-          llvmType.isDoubleTy()) {
-        return success();
-      }
-      return op.emitOpError("type must be non-index integer types, float "
-                            "types, or vector of mentioned types.");
-    }
-    if (auto vectorType = type.dyn_cast<VectorType>()) {
-      if (vectorType.getShape().size() > 1)
-        return op.emitOpError("only 1-d vector is allowed");
-      type = vectorType.getElementType();
-    }
-    if (type.isSignlessIntOrFloat())
-      return success();
-    // Note that memrefs are not supported. We currently don't have a use case
-    // for it, but even if we do, there are challenges:
-    // * if we allow memrefs to cast from/to memref descriptors, then the
-    //   semantics of the cast op depends on the implementation detail of the
-    //   descriptor.
-    // * if we allow memrefs to cast from/to bare pointers, some users might
-    //   alternatively want metadata that only present in the descriptor.
-    //
-    // TODO: re-evaluate the memref cast design when it's needed.
-    return op.emitOpError("type must be non-index integer types, float types, "
-                          "or vector of mentioned types.");
-  };
-  return failure(failed(verifyMLIRCastType(op.in().getType())) ||
-                 failed(verifyMLIRCastType(op.getType())));
+  auto inType = op.in().getType();
+  auto outType = op.getType();
+
+  // Casting between same types is valid.
+  if (inType == outType)
+    return success();
+
+  auto inLLVMType = inType.dyn_cast<LLVM::LLVMType>();
+  auto outLLVMType = outType.dyn_cast<LLVM::LLVMType>();
+
+  // If types are different then exactly one of them must be LLVM type.
+  if (!inLLVMType && !outLLVMType)
+    return op.emitOpError("incorrect cast between two non-LLVM types");
+  if (inLLVMType && outLLVMType)
+    return op.emitOpError("incorrect cast between two different LLVM types");
+
+  auto llvmType = inLLVMType ? inLLVMType : outLLVMType;
+  auto otherType = inLLVMType ? outType : inType;
+
+  // Checks that `stdType` is a signless integer or float type whose LLVM
+  // counterpart is exactly `llvmScalar`: integers must have the same bitwidth
+  // and floats the same format, so e.g. i32 <-> !llvm.i64 and
+  // f16 <-> !llvm.double are rejected.
+  auto verifyScalarCastTypes = [](LLVM::LLVMType llvmScalar,
+                                  Type stdType) -> LogicalResult {
+    if (stdType.isSignlessInteger())
+      return success(llvmScalar.isIntegerTy(stdType.getIntOrFloatBitWidth()));
+    if (stdType.isBF16())
+      return success(llvmScalar.isBFloatTy());
+    if (stdType.isF16())
+      return success(llvmScalar.isHalfTy());
+    if (stdType.isF32())
+      return success(llvmScalar.isFloatTy());
+    if (stdType.isF64())
+      return success(llvmScalar.isDoubleTy());
+    return failure();
+  };
+
+  // Check valid casts in form of allow-list.
+  if (otherType.isSignlessIntOrFloat()) {
+    if (failed(verifyScalarCastTypes(llvmType, otherType)))
+      return op.emitOpError("incorrect cast between scalar types of "
+                            "different kinds or bitwidths");
+    return success();
+  }
+
+  if (auto vectorType = otherType.dyn_cast<VectorType>()) {
+    if (vectorType.getShape().size() > 1)
+      return op.emitOpError("only 1-d vector is allowed");
+    if (!llvmType.isVectorTy())
+      return op.emitOpError("incorrect cast between vector and "
+                            "non-vector types");
+    if (static_cast<int64_t>(llvmType.getVectorNumElements()) !=
+        vectorType.getNumElements())
+      return op.emitOpError("incorrect cast between vectors with different "
+                            "numbers of elements");
+    if (failed(verifyScalarCastTypes(llvmType.getVectorElementType(),
+                                     vectorType.getElementType())))
+      return op.emitOpError("vector elements must have matching "
+                            "non-index integer types or float types");
+    return success();
+  }
+
+  if (auto memRefType = otherType.dyn_cast<MemRefType>()) {
+    if (!llvmType.isStructTy())
+      return op.emitOpError(
+          "memref can be cast only from/to memref descriptor struct type");
+
+    // The descriptor holds the allocated and aligned pointers and the offset,
+    // followed by the size and stride arrays unless the memref has rank 0.
+    int64_t rank = memRefType.getRank();
+    unsigned expectedNumElements = rank == 0 ? 3 : 5;
+    if (llvmType.getStructNumElements() != expectedNumElements)
+      return op.emitOpError("memref convertible struct must have ")
+             << expectedNumElements << " elements";
+
+    auto allocatedPtrType = llvmType.getStructElementType(0);
+    auto alignedPtrType = llvmType.getStructElementType(1);
+    if (!allocatedPtrType.isPointerTy())
+      return op.emitOpError(
+          "first element of memref descriptor must have pointer type");
+    if (!alignedPtrType.isPointerTy())
+      return op.emitOpError(
+          "second element of memref descriptor must have pointer type");
+
+    auto memRefElementType = memRefType.getElementType();
+    if (failed(verifyScalarCastTypes(allocatedPtrType.getPointerElementTy(),
+                                     memRefElementType)))
+      return op.emitOpError(
+          "first element of memref descriptor must be a pointer to type "
+          "convertible to memref's element type");
+    if (failed(verifyScalarCastTypes(alignedPtrType.getPointerElementTy(),
+                                     memRefElementType)))
+      return op.emitOpError(
+          "second element of memref descriptor must be a pointer to type "
+          "convertible to memref's element type");
+
+    // Offset, sizes and strides are lowered index values; only their shape is
+    // checked because their bitwidth depends on the lowering's index bitwidth.
+    if (!llvmType.getStructElementType(2).isIntegerTy())
+      return op.emitOpError(
+          "third element of memref descriptor must have integer type");
+    if (rank == 0)
+      return success();
+    auto isIndexArray = [rank](LLVM::LLVMType type) {
+      return type.isArrayTy() &&
+             static_cast<int64_t>(type.getArrayNumElements()) == rank &&
+             type.getArrayElementType().isIntegerTy();
+    };
+    if (!isIndexArray(llvmType.getStructElementType(3)) ||
+        !isIndexArray(llvmType.getStructElementType(4)))
+      return op.emitOpError("size and stride arrays of memref descriptor must "
+                            "have one integer element per dimension");
+    return success();
+  }
+
+  return op.emitOpError(
+      "casted types must be non-index integer types, float types, "
+      "vector of mentioned types, or memref of mentioned types");
 }
 
 // Parses one of the keywords provided in the list `keywords` and returns the
Index: mlir/test/Dialect/LLVMIR/dialect-cast.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/LLVMIR/dialect-cast.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Rank-0 memrefs lower to a descriptor without size and stride arrays.
+func @mlir_cast_rank0_memref(%arg0 : memref<f32>) {
+  %0 = llvm.mlir.cast %arg0 : memref<f32> to !llvm<"{ float*, float*, i64 }">
+  return
+}
+
+// -----
+
+func @mlir_cast_int_width_mismatch(%arg0 : i32) {
+  // expected-error@+1 {{incorrect cast between scalar types of different kinds or bitwidths}}
+  %0 = llvm.mlir.cast %arg0 : i32 to !llvm.i64
+  return
+}
+
+// -----
+
+func @mlir_cast_float_kind_mismatch(%arg0 : f16) {
+  // expected-error@+1 {{incorrect cast between scalar types of different kinds or bitwidths}}
+  %0 = llvm.mlir.cast %arg0 : f16 to !llvm.double
+  return
+}
+
+// -----
+
+func @mlir_cast_vector_length_mismatch(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{incorrect cast between vectors with different numbers of elements}}
+  %0 = llvm.mlir.cast %arg0 : vector<4xf32> to !llvm<"<8 x float>">
+  return
+}
+
+// -----
+
+func @mlir_cast_memref_rank_mismatch(%arg0 : memref<3x4xf32>) {
+  // expected-error@+1 {{size and stride arrays of memref descriptor must have one integer element per dimension}}
+  %0 = llvm.mlir.cast %arg0 : memref<3x4xf32> to !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+  return
+}
+
+// -----
+
+func @mlir_cast_memref_missing_arrays(%arg0 : memref<2xf32>) {
+  // expected-error@+1 {{memref convertible struct must have 5 elements}}
+  %0 = llvm.mlir.cast %arg0 : memref<2xf32> to !llvm<"{ float*, float*, i64 }">
+  return
+}
Index: mlir/test/Conversion/StandardToLLVM/convert-to-llvmir_32bit.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/StandardToLLVM/convert-to-llvmir_32bit.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @mlir_cast_to_llvm
+// CHECK-SAME: %[[ARG1:[^,]*]]:
+// CHECK-SAME: %[[ARG2:[^,]*]]:
+// CHECK-SAME: %[[ARG3:[^,]*]]:
+// CHECK-SAME: %[[ARG4:[^,]*]]:
+// CHECK-SAME: %[[ARG5:[^,]*]]:
+// CHECK-SAME: %[[ARG6:[^,]*]]:
+// CHECK-SAME: %[[ARG7:[^,]*]]:
+func @mlir_cast_to_llvm(%0 : memref<3x4xf32>) -> !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> {
+  %1 = llvm.mlir.cast %0 : memref<3x4xf32> to !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK-NEXT: llvm.insertvalue %[[ARG1]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG2]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG3]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG4]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG6]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG5]]
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertvalue %[[ARG7]]
+  // CHECK-NEXT: llvm.return %[[RES]]
+  return %1 : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+}
+
+// CHECK-LABEL: func @mlir_cast_from_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_from_llvm(%0 : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">) -> memref<3x4xf32> {
+  %1 = llvm.mlir.cast %0 : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> to memref<3x4xf32>
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : memref<3x4xf32>
+}
Index: mlir/test/Conversion/StandardToLLVM/convert-to-llvmir_64bit.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/StandardToLLVM/convert-to-llvmir_64bit.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @mlir_cast_to_llvm
+// CHECK-SAME: %[[ARG1:[^,]*]]:
+// CHECK-SAME: %[[ARG2:[^,]*]]:
+// CHECK-SAME: %[[ARG3:[^,]*]]:
+// CHECK-SAME: %[[ARG4:[^,]*]]:
+// CHECK-SAME: %[[ARG5:[^,]*]]:
+// CHECK-SAME: %[[ARG6:[^,]*]]:
+// CHECK-SAME: %[[ARG7:[^,]*]]:
+func @mlir_cast_to_llvm(%0 : memref<3x4xf32>) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+  %1 = llvm.mlir.cast %0 : memref<3x4xf32> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: llvm.insertvalue %[[ARG1]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG2]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG3]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG4]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG6]]
+  // CHECK-NEXT: llvm.insertvalue %[[ARG5]]
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertvalue %[[ARG7]]
+  // CHECK-NEXT: llvm.return %[[RES]]
+  return %1 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+}
+
+// CHECK-LABEL: func @mlir_cast_from_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_from_llvm(%0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">) -> memref<3x4xf32> {
+  %1 = llvm.mlir.cast %0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to memref<3x4xf32>
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : memref<3x4xf32>
+}
Index: mlir/test/Conversion/StandardToLLVM/invalid.mlir
===================================================================
--- mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
 
 func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+1 {{'llvm.mlir.cast' op casted types must be non-index integer types, float types, vector of mentioned types, or memref of mentioned types}}
   %1 = llvm.mlir.cast %0 : index to !llvm.i64
   return %1 : !llvm.i64
 }
@@ -9,7 +9,7 @@
 // -----
 
 func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+1 {{'llvm.mlir.cast' op casted types must be non-index integer types, float types, vector of mentioned types, or memref of mentioned types}}
   %1 = llvm.mlir.cast %0 : !llvm.i64 to index
   return %1 : index
 }