diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -57,8 +57,7 @@
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &typeConverter, Location loc,
                         Value val, Type llvmType, int64_t rank, int64_t pos) {
-  assert(rank > 0 && "0-D vector corner case should have been handled already");
-  if (rank == 1) {
+  if (rank <= 1) {
     auto idxType = rewriter.getIndexType();
     auto constant = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter.convertType(idxType),
@@ -987,7 +986,8 @@
 
     // Unroll vector into elementary print calls.
     int64_t rank = vectorType ? vectorType.getRank() : 0;
-    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
+    Type type = vectorType ? vectorType : eltType;
+    emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
               conversion);
     emitCall(rewriter, printOp->getLoc(),
              LLVM::lookupOrCreatePrintNewlineFn(
@@ -1006,10 +1006,12 @@
   };
 
   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
-                 Value value, VectorType vectorType, Operation *printer,
-                 int64_t rank, PrintConversion conversion) const {
+                 Value value, Type type, Operation *printer, int64_t rank,
+                 PrintConversion conversion) const {
+    VectorType vectorType = type.dyn_cast<VectorType>();
     Location loc = op->getLoc();
-    if (rank == 0) {
+    if (!vectorType) {
+      assert(rank == 0 && "The scalar case expects rank == 0");
       switch (conversion) {
       case PrintConversion::ZeroExt64:
         value = rewriter.create<arith::ExtUIOp>(
@@ -1030,12 +1032,30 @@
              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
     Operation *printComma =
         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
+
+    if (rank <= 1) {
+      auto reducedType = vectorType.getElementType();
+      auto llvmType = typeConverter->convertType(reducedType);
+      // 0-D vectors are lowered to 1-element vectors; print one element.
+      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
+      for (int64_t d = 0; d < dim; ++d) {
+        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
+                                     llvmType, /*rank=*/0, /*pos=*/d);
+        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
+                  conversion);
+        if (d != dim - 1)
+          emitCall(rewriter, loc, printComma);
+      }
+      emitCall(
+          rewriter, loc,
+          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+      return;
+    }
+
     int64_t dim = vectorType.getDimSize(0);
     for (int64_t d = 0; d < dim; ++d) {
-      auto reducedType =
-          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
-      auto llvmType = typeConverter->convertType(
-          rank > 1 ? reducedType : vectorType.getElementType());
+      auto reducedType = reducedVectorTypeFront(vectorType);
+      auto llvmType = typeConverter->convertType(reducedType);
       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                    llvmType, rank, d);
       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -832,6 +832,23 @@
 
 // -----
 
+func @vector_print_vector_0d(%arg0: vector<f32>) {
+  vector.print %arg0 : vector<f32>
+  return
+}
+// CHECK-LABEL: @vector_print_vector_0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: llvm.call @printOpen() : () -> ()
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
+// CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
+// CHECK: llvm.call @printClose() : () -> ()
+// CHECK: llvm.call @printNewline() : () -> ()
+// CHECK: return
+
+// -----
+
 func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -15,10 +15,20 @@
   return %1: vector<f32>
 }
 
+func @print_vector_0d(%a: vector<f32>) {
+  // CHECK: ( 42 )
+  vector.print %a: vector<f32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
   %2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
   call @extract_element_0d(%2) : (vector<f32>) -> ()
+
+  %3 = arith.constant dense<42.0> : vector<f32>
+  call @print_vector_0d(%3) : (vector<f32>) -> ()
+
   return
 }