[mlir][vector] Lower vector.print of i1 through print_i32; drop print_i1

In the rank-0 case of the vector.print lowering, a value of LLVM i1 type
is widened to i32 with llvm.select(%b, 1 : i32, 0 : i32) and passed to
print_i32; the printer chosen for an i1 element type is now print_i32.
The print_i1 entry point is removed from the CRunnerUtils runtime
library (header declaration and implementation). Per the comment added
in the lowering, this avoids a print_i1 method with an unclear ABI.

The i32 constants and the select are built as LLVM::ConstantOp and
LLVM::SelectOp, matching the llvm.mlir.constant / llvm.select ops the
updated FileCheck test expects.

diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -200,7 +200,6 @@
 //===----------------------------------------------------------------------===//
 // Small runtime support "lib" for vector.print lowering during codegen.
 //===----------------------------------------------------------------------===//
-extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i1(bool b);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i32(int32_t i);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i64(int64_t l);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void print_f32(float f);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -978,9 +978,7 @@
     Type eltType = vectorType ? vectorType.getElementType() : printType;
     int64_t rank = vectorType ? vectorType.getRank() : 0;
     Operation *printer;
-    if (eltType.isSignlessInteger(1))
-      printer = getPrintI1(op);
-    else if (eltType.isSignlessInteger(32))
+    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
       printer = getPrintI32(op);
     else if (eltType.isSignlessInteger(64))
       printer = getPrintI64(op);
@@ -1004,6 +1002,17 @@
                  int64_t rank) const {
     Location loc = op->getLoc();
     if (rank == 0) {
+      if (value.getType() ==
+          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
+        // Convert i1 (bool) to i32 so we can use the print_i32 method.
+        // This avoids the need for a print_i1 method with an unclear ABI.
+        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+        auto trueVal = rewriter.create<LLVM::ConstantOp>(
+            loc, i32Type, rewriter.getI32IntegerAttr(1));
+        auto falseVal = rewriter.create<LLVM::ConstantOp>(
+            loc, i32Type, rewriter.getI32IntegerAttr(0));
+        value = rewriter.create<LLVM::SelectOp>(loc, value, trueVal, falseVal);
+      }
       emitCall(rewriter, loc, printer, value);
       return;
     }
@@ -1047,11 +1056,6 @@
   }
 
   // Helpers for method names.
-  Operation *getPrintI1(Operation *op) const {
-    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
-    return getPrint(op, dialect, "print_i1",
-                    LLVM::LLVMType::getInt1Ty(dialect));
-  }
   Operation *getPrintI32(Operation *op) const {
     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
     return getPrint(op, dialect, "print_i32",
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -23,7 +23,6 @@
 // By providing elementary printing methods only, this
 // library can remain fully unaware of low-level implementation
 // details of our vectors. Also useful for direct LLVM IR output.
-extern "C" void print_i1(bool b) { fputc(b ? '1' : '0', stdout); }
 extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
 extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
 extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -439,8 +439,11 @@
 }
 // CHECK-LABEL: llvm.func @vector_print_scalar_i1(
 // CHECK-SAME: %[[A:.*]]: !llvm.i1)
-// CHECK: llvm.call @print_i1(%[[A]]) : (!llvm.i1) -> ()
-// CHECK: llvm.call @print_newline() : () -> ()
+// CHECK: %[[T:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: %[[F:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[S:.*]] = llvm.select %[[A]], %[[T]], %[[F]] : !llvm.i1, !llvm.i32
+// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
 
 func @vector_print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32