diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -809,7 +809,11 @@
     Type eltType = vectorType ? vectorType.getElementType() : printType;
     int64_t rank = vectorType ? vectorType.getRank() : 0;
     Operation *printer;
-    if (eltType.isF32())
+    if (eltType.isInteger(32))
+      printer = getPrintI32(op);
+    else if (eltType.isInteger(64))
+      printer = getPrintI64(op);
+    else if (eltType.isF32())
       printer = getPrintFloat(op);
     else if (eltType.isF64())
       printer = getPrintDouble(op);
@@ -872,6 +876,16 @@
   }
 
   // Helpers for method names.
+  Operation *getPrintI32(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_i32",
+                    LLVM::LLVMType::getInt32Ty(dialect));
+  }
+  Operation *getPrintI64(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_i64",
+                    LLVM::LLVMType::getInt64Ty(dialect));
+  }
   Operation *getPrintFloat(Operation *op) const {
     LLVM::LLVMDialect *dialect = lowering.getDialect();
     return getPrint(op, dialect, "print_f32",
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -395,15 +395,42 @@
 // CHECK: llvm.mlir.constant(0 : index
 // CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 
-func @vector_print_scalar(%arg0: f32) {
+func @vector_print_scalar_i32(%arg0: i32) {
+  vector.print %arg0 : i32
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i32
+// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.i32
+// CHECK: llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i64(%arg0: i64) {
+  vector.print %arg0 : i64
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i64
+// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.i64
+// CHECK: llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar
+// CHECK-LABEL: llvm.func @vector_print_scalar_f32
 // CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float
 // CHECK: llvm.call @print_f32(%[[A]]) : (!llvm.float) -> ()
 // CHECK: llvm.call @print_newline() : () -> ()
 
+func @vector_print_scalar_f64(%arg0: f64) {
+  vector.print %arg0 : f64
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_f64
+// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.double
+// CHECK: llvm.call @print_f64(%[[A]]) : (!llvm.double) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
 func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return
diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
--- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
+++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
@@ -13,6 +13,7 @@
 
 #include "include/mlir_runner_utils.h"
 
+#include <cinttypes>
 #include <cstdio>
 
 extern "C" void
@@ -76,7 +77,9 @@
 // Small runtime support "lib" for vector.print lowering.
 // By providing elementary printing methods only, this
 // library can remain fully unaware of low-level implementation
-// details of our vectors.
+// details of our vectors. Also useful for direct LLVM IR output.
+extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
+extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
 extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
 extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
 extern "C" void print_open() { fputs("( ", stdout); }