diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -57,8 +57,8 @@
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &typeConverter, Location loc,
                         Value val, Type llvmType, int64_t rank, int64_t pos) {
-  assert(rank > 0 && "0-D vector corner case should have been handled already");
-  if (rank == 1) {
+  // 0-D vectors lower to 1-element LLVM vectors, so they share the 1-D path.
+  if (rank == 0 || rank == 1) {
     auto idxType = rewriter.getIndexType();
     auto constant = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter.convertType(idxType),
@@ -1009,7 +1009,9 @@
                  Value value, VectorType vectorType, Operation *printer,
                  int64_t rank, PrintConversion conversion) const {
     Location loc = op->getLoc();
-    if (rank == 0) {
+    // Note that we can have 0-D vectors, in which case `vectorType != nullptr`,
+    // but `rank == 0`.
+    if (!vectorType && rank == 0) {
       switch (conversion) {
       case PrintConversion::ZeroExt64:
         value = rewriter.create<arith::ExtUIOp>(
@@ -1030,7 +1032,9 @@
              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
     Operation *printComma =
         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
-    int64_t dim = vectorType.getDimSize(0);
+
+    // `rank` can be 0 for 0-D vectors.
+    int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
     for (int64_t d = 0; d < dim; ++d) {
       auto reducedType =
           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
@@ -1038,7 +1042,8 @@
           rank > 1 ? reducedType : vectorType.getElementType());
       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                    llvmType, rank, d);
-      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
+      int64_t lowerRank = rank == 0 ? 0 : rank - 1;
+      emitRanks(rewriter, op, nestedVal, reducedType, printer, lowerRank,
                 conversion);
       if (d != dim - 1)
         emitCall(rewriter, loc, printComma);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -832,6 +832,23 @@
 
 // -----
 
+func @vector_print_vector_0d(%arg0: vector<f32>) {
+  vector.print %arg0 : vector<f32>
+  return
+}
+// CHECK-LABEL: @vector_print_vector_0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: llvm.call @printOpen() : () -> ()
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
+// CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
+// CHECK: llvm.call @printClose() : () -> ()
+// CHECK: llvm.call @printNewline() : () -> ()
+// CHECK: return
+
+// -----
+
 func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -15,10 +15,20 @@
   return %1: vector<f32>
 }
 
+func @print_vector_0d(%a: vector<f32>) {
+  // CHECK: ( 42 )
+  vector.print %a: vector<f32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
   %2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
   call @extract_element_0d(%2) : (vector<f32>) -> ()
+
+  %3 = arith.constant dense<42.0> : vector<f32>
+  call @print_vector_0d(%3) : (vector<f32>) -> ()
+
   return
 }