Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -988,39 +988,104 @@
 //===----------------------------------------------------------------------===//
 
+/// Verifies an `llvm.mlir.cast`. Identical types are accepted; otherwise
+/// exactly one side must be an LLVM dialect type and the other its
+/// counterpart: an integer of the same bit width, the matching float type, a
+/// 1-D vector of those, or a memref whose descriptor struct has 5 fields (3
+/// for rank 0) with the first two being pointers to the memref element type.
 static LogicalResult verify(DialectCastOp op) {
-  auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
-    if (auto llvmType = type.dyn_cast<LLVMType>()) {
-      if (llvmType.isVectorTy())
-        llvmType = llvmType.getVectorElementType();
-      if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
-          llvmType.isHalfTy() || llvmType.isFloatTy() ||
-          llvmType.isDoubleTy()) {
-        return success();
-      }
-      return op.emitOpError("type must be non-index integer types, float "
-                            "types, or vector of mentioned types.");
-    }
-    if (auto vectorType = type.dyn_cast<VectorType>()) {
-      if (vectorType.getShape().size() > 1)
-        return op.emitOpError("only 1-d vector is allowed");
-      type = vectorType.getElementType();
-    }
-    if (type.isSignlessIntOrFloat())
-      return success();
-    // Note that memrefs are not supported. We currently don't have a use case
-    // for it, but even if we do, there are challenges:
-    // * if we allow memrefs to cast from/to memref descriptors, then the
-    // semantics of the cast op depends on the implementation detail of the
-    // descriptor.
-    // * if we allow memrefs to cast from/to bare pointers, some users might
-    // alternatively want metadata that only present in the descriptor.
-    //
-    // TODO: re-evaluate the memref cast design when it's needed.
-    return op.emitOpError("type must be non-index integer types, float types, "
-                          "or vector of mentioned types.");
+  Type inType = op.in().getType();
+  Type outType = op.getType();
+
+  // Casting between same types is valid.
+  if (inType == outType)
+    return success();
+
+  auto inLLVMType = inType.dyn_cast<LLVMType>();
+  auto outLLVMType = outType.dyn_cast<LLVMType>();
+
+  // If types are different then exactly one of them must be LLVM type.
+  if (!inLLVMType && !outLLVMType)
+    return op.emitOpError("incorrect cast between two non-LLVM types");
+  if (inLLVMType && outLLVMType)
+    return op.emitOpError("incorrect cast between two different LLVM types");
+
+  LLVMType llvmType = inLLVMType ? inLLVMType : outLLVMType;
+  Type otherType = inLLVMType ? outType : inType;
+
+  // TODO: Reduce duplication between below code and LLVMTypeConverter.
+  auto verifyScalarCastTypes = [](LLVMType llvmType,
+                                  Type stdType) -> LogicalResult {
+    if (stdType.isa<FloatType>())
+      return success((llvmType.isBFloatTy() && stdType.isBF16()) ||
+                     (llvmType.isDoubleTy() && stdType.isF64()) ||
+                     (llvmType.isFloatTy() && stdType.isF32()) ||
+                     (llvmType.isHalfTy() && stdType.isF16()));
+
+    if (auto integerType = stdType.dyn_cast<IntegerType>())
+      return success(llvmType.isIntegerTy() &&
+                     llvmType.getIntegerBitWidth() == integerType.getWidth());
+
+    return failure();
   };
-  return failure(failed(verifyMLIRCastType(op.in().getType())) ||
-                 failed(verifyMLIRCastType(op.getType())));
+
+  // Check valid casts in form of allow-list.
+  if (succeeded(verifyScalarCastTypes(llvmType, otherType)))
+    return success();
+
+  if (auto vectorType = otherType.dyn_cast<VectorType>()) {
+    if (vectorType.getShape().size() > 1)
+      return op.emitOpError("only 1-d vector is allowed");
+    if (!llvmType.isVectorTy())
+      return op.emitOpError("incorrect cast between vector and "
+                            "non-vector types");
+
+    LLVMType llvmElementType = llvmType.getVectorElementType();
+    Type vectorElementType = vectorType.getElementType();
+
+    if (failed(verifyScalarCastTypes(llvmElementType, vectorElementType)))
+      return op.emitOpError("vector elements must have matching "
+                            "non-index integer types or float types");
+
+    return success();
+  }
+
+  if (auto memRefType = otherType.dyn_cast<MemRefType>()) {
+    Type memRefElementType = memRefType.getElementType();
+
+    if (!llvmType.isStructTy())
+      return op.emitOpError(
+          "memref can be cast only from/to memref descriptor struct type");
+
+    if ((memRefType.getRank() > 0 && llvmType.getStructNumElements() != 5) ||
+        (memRefType.getRank() == 0 && llvmType.getStructNumElements() != 3))
+      return op.emitOpError(
+          "incorrect number of elements in memref descriptor struct");
+
+    LLVMType allocatedPtrType = llvmType.getStructElementType(0);
+    LLVMType alignedPtrType = llvmType.getStructElementType(1);
+
+    if (!allocatedPtrType.isPointerTy() || !alignedPtrType.isPointerTy())
+      return op.emitOpError("first and second element of memref descriptor "
+                            "must have pointer type");
+
+    LLVMType allocatedElementType = allocatedPtrType.getPointerElementTy();
+    LLVMType alignedElementType = alignedPtrType.getPointerElementTy();
+
+    if (allocatedElementType != alignedElementType ||
+        failed(verifyScalarCastTypes(allocatedElementType, memRefElementType)))
+      return op.emitOpError(
+          "first and second element of memref descriptor must have same types, "
+          "that is pointer to type convertible to memref's element type");
+
+    // For now skip checking other fields of memref descriptor as they rely on
+    // index lowering.
+
+    return success();
+  }
+
+  return op.emitOpError("incorrect cast").attachNote()
+         << "casting is supported for non-index integers and floats, "
+            "vectors and memrefs of those";
 }
 
 // Parses one of the keywords provided in the list `keywords` and returns the
Index: mlir/test/Conversion/StandardToLLVM/invalid.mlir
===================================================================
--- mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
 
 func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+2 {{'llvm.mlir.cast' op incorrect cast}}
+  // expected-note@+1 {{casting is supported for}}
   %1 = llvm.mlir.cast %0 : index to !llvm.i64
   return %1 : !llvm.i64
 }
@@ -9,7 +10,8 @@
 // -----
 
 func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index {
-  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  // expected-error@+2 {{'llvm.mlir.cast' op incorrect cast}}
+  // expected-note@+1 {{casting is supported for}}
   %1 = llvm.mlir.cast %0 : !llvm.i64 to index
   return %1 : index
 }
@@ -17,7 +19,8 @@
 // -----
 
 func @mlir_cast_to_llvm_int(%0 : i32) -> !llvm.i64 {
-  // expected-error@+1 {{failed to legalize operation 'llvm.mlir.cast' that was explicitly marked illegal}}
+  // expected-error@+2 {{'llvm.mlir.cast' op incorrect cast}}
+  // expected-note@+1 {{casting is supported for}}
   %1 = llvm.mlir.cast %0 : i32 to !llvm.i64
   return %1 : !llvm.i64
 }
Index: mlir/test/Dialect/LLVMIR/invalid.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/invalid.mlir
+++ mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -613,3 +613,95 @@
   // expected-error @+1 {{can be given only acquire, release, acq_rel, and seq_cst orderings}}
   llvm.fence syncscope("agent") monotonic
 }
+
+// -----
+
+// Check dialect cast operation verifier.
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (index)
+  // expected-error @+2 {{incorrect cast}}
+  // expected-note @+1 {{casting is supported for}}
+  %1 = llvm.mlir.cast %0 : index to !llvm.i64
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (f32)
+  // expected-error @+1 {{incorrect cast between two non-LLVM types}}
+  %1 = llvm.mlir.cast %0 : f32 to f16
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (!llvm.i32)
+  // expected-error @+1 {{incorrect cast between two different LLVM types}}
+  %1 = llvm.mlir.cast %0 : !llvm.i32 to !llvm.float
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (f32)
+  // expected-error @+2 {{incorrect cast}}
+  // expected-note @+1 {{casting is supported for}}
+  %1 = llvm.mlir.cast %0 : f32 to !llvm.half
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (vector<2x3xf32>)
+  // expected-error @+1 {{only 1-d vector is allowed}}
+  %1 = llvm.mlir.cast %0 : vector<2x3xf32> to !llvm<"[2 x <3 x float>]">
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (vector<2xf32>)
+  // expected-error @+1 {{incorrect cast between vector and non-vector types}}
+  %1 = llvm.mlir.cast %0 : vector<2xf32> to !llvm.float
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (vector<2xf32>)
+  // expected-error @+1 {{vector elements must have matching non-index integer types or float types}}
+  %1 = llvm.mlir.cast %0 : vector<2xf32> to !llvm<"<2 x i32>">
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (memref<2xf32>)
+  // expected-error @+1 {{memref can be cast only from/to memref descriptor struct type}}
+  %1 = llvm.mlir.cast %0 : memref<2xf32> to !llvm.float
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (memref<2xf32>)
+  // expected-error @+1 {{incorrect number of elements in memref descriptor struct}}
+  %1 = llvm.mlir.cast %0 : memref<2xf32> to !llvm<"{float*, float*, i64}">
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (memref<2xf32>)
+  // expected-error @+1 {{first and second element of memref descriptor must have pointer type}}
+  %1 = llvm.mlir.cast %0 : memref<2xf32> to !llvm<"{float, float*, i64, [2 x i64], [2 x i64]}">
+}
+
+// -----
+
+func @invalid_dialect_cast() {
+  %0 = "op"() : () -> (memref<2xf32>)
+  // expected-error @+1 {{first and second element of memref descriptor must have same types, that is pointer to type convertible to memref's element type}}
+  %1 = llvm.mlir.cast %0 : memref<2xf32> to !llvm<"{half*, half*, i64, [2 x i64], [2 x i64]}">
+}
Index: mlir/test/Dialect/LLVMIR/roundtrip.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -333,3 +333,55 @@
   llvm.fence release
   return
 }
+
+// Dialect cast operation.
+
+// CHECK-LABEL: @castI32
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.i32
+// CHECK: %[[CAST:.*]] = llvm.mlir.cast %[[ARG0]] : !llvm.i32 to i32
+// CHECK: llvm.mlir.cast %[[CAST]] : i32 to !llvm.i32
+func @castI32(%arg0: !llvm.i32) {
+  %0 = llvm.mlir.cast %arg0 : !llvm.i32 to i32
+  %1 = llvm.mlir.cast %0 : i32 to !llvm.i32
+  return
+}
+
+// CHECK-LABEL: @castFloat
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.float
+// CHECK: %[[CAST:.*]] = llvm.mlir.cast %[[ARG0]] : !llvm.float to f32
+// CHECK: llvm.mlir.cast %[[CAST]] : f32 to !llvm.float
+func @castFloat(%arg0: !llvm.float) {
+  %0 = llvm.mlir.cast %arg0 : !llvm.float to f32
+  %1 = llvm.mlir.cast %0 : f32 to !llvm.float
+  return
+}
+
+// CHECK-LABEL: @castVector
+// CHECK-SAME: %[[ARG0:.*]]: !llvm<"<3 x float>">
+// CHECK: %[[CAST:.*]] = llvm.mlir.cast %[[ARG0]] : !llvm<"<3 x float>"> to vector<3xf32>
+// CHECK: llvm.mlir.cast %[[CAST]] : vector<3xf32> to !llvm<"<3 x float>">
+func @castVector(%arg0: !llvm<"<3 x float>">) {
+  %0 = llvm.mlir.cast %arg0 : !llvm<"<3 x float>"> to vector<3xf32>
+  %1 = llvm.mlir.cast %0 : vector<3xf32> to !llvm<"<3 x float>">
+  return
+}
+
+// CHECK-LABEL: @castMemref
+// CHECK-SAME: %[[ARG0:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[CAST:.*]] = llvm.mlir.cast %[[ARG0]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to memref<2x3xf32>
+// CHECK: llvm.mlir.cast %[[CAST]] : memref<2x3xf32> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+func @castMemref(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">) {
+  %0 = llvm.mlir.cast %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to memref<2x3xf32>
+  %1 = llvm.mlir.cast %0 : memref<2x3xf32> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  return
+}
+
+// CHECK-LABEL: @castMemref0D
+// CHECK-SAME: %[[ARG0:.*]]: !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[CAST:.*]] = llvm.mlir.cast %[[ARG0]] : !llvm<"{ float*, float*, i64 }"> to memref<f32>
+// CHECK: llvm.mlir.cast %[[CAST]] : memref<f32> to !llvm<"{ float*, float*, i64 }">
+func @castMemref0D(%arg0: !llvm<"{ float*, float*, i64 }">) {
+  %0 = llvm.mlir.cast %arg0 : !llvm<"{ float*, float*, i64 }"> to memref<f32>
+  %1 = llvm.mlir.cast %0 : memref<f32> to !llvm<"{ float*, float*, i64 }">
+  return
+}