diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir @@ -0,0 +1,76 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// +// Test printing of signless, signed, and unsigned integer vectors of bit widths 1, 8, 16, 32, and 64. +// +func @entry() { + %0 = std.constant dense<[true, false, -1, 0, 1]> : vector<5xi1> + vector.print %0 : vector<5xi1> + // CHECK: ( 1, 0, 1, 0, 1 ) + + %1 = std.constant dense<[true, false, -1, 0]> : vector<4xsi1> + vector.print %1 : vector<4xsi1> + // CHECK: ( 1, 0, 1, 0 ) + + %2 = std.constant dense<[true, false, 0, 1]> : vector<4xui1> + vector.print %2 : vector<4xui1> + // CHECK: ( 1, 0, 0, 1 ) + + %3 = std.constant dense<[-128, -127, -1, 0, 1, 127, 128, 254, 255]> : vector<9xi8> + vector.print %3 : vector<9xi8> + // CHECK: ( -128, -127, -1, 0, 1, 127, -128, -2, -1 ) + + %4 = std.constant dense<[-128, -127, -1, 0, 1, 127]> : vector<6xsi8> + vector.print %4 : vector<6xsi8> + // CHECK: ( -128, -127, -1, 0, 1, 127 ) + + %5 = std.constant dense<[0, 1, 127, 128, 254, 255]> : vector<6xui8> + vector.print %5 : vector<6xui8> + // CHECK: ( 0, 1, 127, 128, 254, 255 ) + + %6 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767, 32768, 65534, 65535]> : vector<9xi16> + vector.print %6 : vector<9xi16> + // CHECK: ( -32768, -32767, -1, 0, 1, 32767, -32768, -2, -1 ) + + %7 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767]> : vector<6xsi16> + vector.print %7 : vector<6xsi16> + // CHECK: ( -32768, -32767, -1, 0, 1, 32767 ) + + %8 = std.constant dense<[0, 1, 32767, 32768, 65534, 65535]> : vector<6xui16> + vector.print %8 : vector<6xui16> + // CHECK: ( 0, 1, 32767, 32768, 65534, 65535 ) + + %9 = 
std.constant dense<[-2147483648, -2147483647, -1, 0, 1, + 2147483647, 2147483648, 4294967294, 4294967295]> : vector<9xi32> + vector.print %9 : vector<9xi32> + // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647, -2147483648, -2, -1 ) + + %10 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1, 2147483647]> : vector<6xsi32> + vector.print %10 : vector<6xsi32> + // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647 ) + + %11 = std.constant dense<[0, 1, 2147483647, 2147483648, 4294967294, 4294967295]> : vector<6xui32> + vector.print %11 : vector<6xui32> + // CHECK: ( 0, 1, 2147483647, 2147483648, 4294967294, 4294967295 ) + + %12 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1, + 9223372036854775807, 9223372036854775808, + 18446744073709551614, 18446744073709551615]> : vector<9xi64> + vector.print %12 : vector<9xi64> + // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807, -9223372036854775808, -2, -1 ) + + %13 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1, + 9223372036854775807]> : vector<6xsi64> + vector.print %13 : vector<6xsi64> + // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807 ) + + %14 = std.constant dense<[0, 1, 9223372036854775807, 9223372036854775808, + 18446744073709551614, 18446744073709551615]> : vector<6xui64> + vector.print %14 : vector<6xui64> + // CHECK: ( 0, 1, 9223372036854775807, 9223372036854775808, 18446744073709551614, 18446744073709551615 ) + + return +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1319,44 +1319,96 @@ if (typeConverter.convertType(printType) == nullptr) return failure(); - // Make sure element type has runtime support (currently just Float/Double).
+ // Make sure element type has runtime support (f32, f64, or an integer of at most 64 bits). + PrintConversion conversion = PrintConversion::None; VectorType vectorType = printType.dyn_cast<VectorType>(); Type eltType = vectorType ? vectorType.getElementType() : printType; - int64_t rank = vectorType ? vectorType.getRank() : 0; Operation *printer; - if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32)) - printer = getPrintI32(op); - else if (eltType.isSignlessInteger(64)) - printer = getPrintI64(op); - else if (eltType.isF32()) + if (eltType.isF32()) { printer = getPrintFloat(op); - else if (eltType.isF64()) + } else if (eltType.isF64()) { printer = getPrintDouble(op); - else + } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { + // Integers need a zero or sign extension on the operand + // (depending on the source type) as well as a signed or + // unsigned print method. Up to 64-bit is supported. + unsigned width = intTy.getWidth(); + if (intTy.isUnsigned()) { + if (width <= 32) { + if (width < 32) + conversion = PrintConversion::ZeroExt32; + printer = getPrintU32(op); + } else if (width <= 64) { + if (width < 64) + conversion = PrintConversion::ZeroExt64; + printer = getPrintU64(op); + } else { + return failure(); + } + } else { + assert(intTy.isSignless() || intTy.isSigned()); + if (width <= 32) { + // Note that we *always* zero extend booleans (1-bit integers), + // so that true/false is printed as 1/0 rather than -1/0. + if (width == 1) + conversion = PrintConversion::ZeroExt32; + else if (width < 32) + conversion = PrintConversion::SignExt32; + printer = getPrintI32(op); + } else if (width <= 64) { + if (width < 64) + conversion = PrintConversion::SignExt64; + printer = getPrintI64(op); + } else { + return failure(); + } + } + } else { return failure(); + } // Unroll vector into elementary print calls. - emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); + int64_t rank = vectorType ?
vectorType.getRank() : 0; + emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, + conversion); emitCall(rewriter, op->getLoc(), getPrintNewline(op)); rewriter.eraseOp(op); return success(); } private: + enum class PrintConversion { + None, + ZeroExt32, + SignExt32, + ZeroExt64, + SignExt64 + }; + void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, Value value, VectorType vectorType, Operation *printer, - int64_t rank) const { + int64_t rank, PrintConversion conversion) const { Location loc = op->getLoc(); if (rank == 0) { - if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) { - // Convert i1 (bool) to i32 so we can use the print_i32 method. - // This avoids the need for a print_i1 method with an unclear ABI. - auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); - auto trueVal = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, rewriter.getI32IntegerAttr(1)); - auto falseVal = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, rewriter.getI32IntegerAttr(0)); - value = rewriter.create<LLVM::SelectOp>(loc, value, trueVal, falseVal); + switch (conversion) { + case PrintConversion::ZeroExt32: + value = rewriter.create<ZeroExtendIOp>( + loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext())); + break; + case PrintConversion::SignExt32: + value = rewriter.create<SignExtendIOp>( + loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext())); + break; + case PrintConversion::ZeroExt64: + value = rewriter.create<ZeroExtendIOp>( + loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); + break; + case PrintConversion::SignExt64: + value = rewriter.create<SignExtendIOp>( + loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); + break; + case PrintConversion::None: + break; } emitCall(rewriter, loc, printer, value); return; @@ -1372,7 +1424,8 @@ rank > 1 ?
reducedType : vectorType.getElementType()); Value nestedVal = extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); - emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); + emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, + conversion); if (d != dim - 1) emitCall(rewriter, loc, printComma); } @@ -1410,6 +1463,14 @@ return getPrint(op, "print_i64", LLVM::LLVMType::getInt64Ty(op->getContext())); } + Operation *getPrintU32(Operation *op) const { + return getPrint(op, "print_u32", + LLVM::LLVMType::getInt32Ty(op->getContext())); + } + Operation *getPrintU64(Operation *op) const { + return getPrint(op, "print_u64", + LLVM::LLVMType::getInt64Ty(op->getContext())); + } Operation *getPrintFloat(Operation *op) const { return getPrint(op, "print_f32", LLVM::LLVMType::getFloatTy(op->getContext())); diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp --- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp @@ -25,6 +25,8 @@ // details of our vectors. Also useful for direct LLVM IR output. extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); } extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); } +extern "C" void print_u32(uint32_t i) { fprintf(stdout, "%" PRIu32, i); } +extern "C" void print_u64(uint64_t l) { fprintf(stdout, "%" PRIu64, l); } extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); } extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); } extern "C" void print_open() { fputs("( ", stdout); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -433,14 +433,45 @@ vector.print %arg0 : i1 return } +// +// Type "boolean" (i1) always uses zero extension, so true prints as 1 rather than -1.
+// // CHECK-LABEL: llvm.func @vector_print_scalar_i1( // CHECK-SAME: %[[A:.*]]: !llvm.i1) -// CHECK: %[[T:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 -// CHECK: %[[F:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 -// CHECK: %[[S:.*]] = llvm.select %[[A]], %[[T]], %[[F]] : !llvm.i1, !llvm.i32 +// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i1 to !llvm.i32 +// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> () +// CHECK: llvm.call @print_newline() : () -> () + +func @vector_print_scalar_i4(%arg0: i4) { + vector.print %arg0 : i4 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_i4( +// CHECK-SAME: %[[A:.*]]: !llvm.i4) +// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32 // CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> () // CHECK: llvm.call @print_newline() : () -> () +func @vector_print_scalar_si4(%arg0: si4) { + vector.print %arg0 : si4 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_si4( +// CHECK-SAME: %[[A:.*]]: !llvm.i4) +// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32 +// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> () +// CHECK: llvm.call @print_newline() : () -> () + +func @vector_print_scalar_ui4(%arg0: ui4) { + vector.print %arg0 : ui4 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_ui4( +// CHECK-SAME: %[[A:.*]]: !llvm.i4) +// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i4 to !llvm.i32 +// CHECK: llvm.call @print_u32(%[[S]]) : (!llvm.i32) -> () +// CHECK: llvm.call @print_newline() : () -> () + func @vector_print_scalar_i32(%arg0: i32) { vector.print %arg0 : i32 return @@ -450,6 +481,45 @@ // CHECK: llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> () // CHECK: llvm.call @print_newline() : () -> () +func @vector_print_scalar_ui32(%arg0: ui32) { + vector.print %arg0 : ui32 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_ui32( +// CHECK-SAME: %[[A:.*]]: !llvm.i32) +// CHECK: llvm.call @print_u32(%[[A]]) : (!llvm.i32) -> () +// CHECK: llvm.call @print_newline() : () ->
() + +func @vector_print_scalar_i40(%arg0: i40) { + vector.print %arg0 : i40 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_i40( +// CHECK-SAME: %[[A:.*]]: !llvm.i40) +// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64 +// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> () +// CHECK: llvm.call @print_newline() : () -> () + +func @vector_print_scalar_si40(%arg0: si40) { + vector.print %arg0 : si40 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_si40( +// CHECK-SAME: %[[A:.*]]: !llvm.i40) +// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64 +// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> () +// CHECK: llvm.call @print_newline() : () -> () + +func @vector_print_scalar_ui40(%arg0: ui40) { + vector.print %arg0 : ui40 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_ui40( +// CHECK-SAME: %[[A:.*]]: !llvm.i40) +// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i40 to !llvm.i64 +// CHECK: llvm.call @print_u64(%[[S]]) : (!llvm.i64) -> () +// CHECK: llvm.call @print_newline() : () -> () + func @vector_print_scalar_i64(%arg0: i64) { vector.print %arg0 : i64 return @@ -459,6 +529,15 @@ // CHECK: llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> () // CHECK: llvm.call @print_newline() : () -> () +func @vector_print_scalar_ui64(%arg0: ui64) { + vector.print %arg0 : ui64 + return +} +// CHECK-LABEL: llvm.func @vector_print_scalar_ui64( +// CHECK-SAME: %[[A:.*]]: !llvm.i64) +// CHECK: llvm.call @print_u64(%[[A]]) : (!llvm.i64) -> () +// CHECK: llvm.call @print_newline() : () -> () + func @vector_print_scalar_f32(%arg0: f32) { vector.print %arg0 : f32 return