[mlir][Vector] Allow 0-D vectors in vector.extractelement.

This patch:
  * makes IsVectorTypePred reject rank-0 vectors and adds the temporary
    IsVectorOfAnyRankTypePred / VectorOfAnyRankOf / AnyVectorOfAnyRank
    constraints so ops can be moved to 0-D vectors gradually;
  * makes the `position` operand of vector.extractelement optional: the
    verifier requires it to be absent for 0-D vectors and present for 1-D
    vectors, and rejects rank > 1;
  * converts a 0-D `vector<T>` to LLVM `vector<1xT>` and lowers a 0-D
    vector.extractelement to llvm.extractelement at constant index 0;
  * adds round-trip, verifier, LLVM-conversion and CPU integration tests.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -482,14 +482,20 @@ TypesMatchWith<"result type matches element type of vector operand", "vector", "result", "$_self.cast<ShapedType>().getElementType()">]>, - Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, + Optional<AnySignlessIntegerOrIndex>:$position)>, Results<(outs AnyType:$result)> { let summary = "extractelement operation"; let description = [{ - Takes an 1-D vector and a dynamic index position and extracts the - scalar at that position. Note that this instruction resembles - vector.extract, but is restricted to 1-D vectors and relaxed - to dynamic indices. It is meant to be closer to LLVM's version: + Takes a 0-D or 1-D vector and an optional dynamic index position and + extracts the scalar at that position. + + Note that this instruction resembles vector.extract, but is restricted to + 0-D and 1-D vectors and relaxed to dynamic indices. + If the vector is 0-D, the position must be omitted. + + + It is meant to be closer to LLVM's version: https://llvm.org/docs/LangRef.html#extractelement-instruction Example: @@ -497,14 +503,18 @@ ```mlir %c = arith.constant 15 : i32 %1 = vector.extractelement %0[%c : i32]: vector<16xf32> + %2 = vector.extractelement %z[]: vector<f32> ``` }]; let assemblyFormat = [{ - $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector) + $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector) }]; let builders = [ - OpBuilder<(ins "Value":$source, "Value":$position)> + // 0-D builder. + OpBuilder<(ins "Value":$source)>, + // 1-D + position builder.
+ OpBuilder<(ins "Value":$source, "Value":$position)>, ]; let extraClassDeclaration = [{ VectorType getVectorType() { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -208,7 +208,12 @@ //===----------------------------------------------------------------------===// // Whether a type is a VectorType. -def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">; +// Explicitly disallow 0-D vectors for now until we have good enough coverage. +def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">, + CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>; + +// Temporary vector type clone that allows gradual transition to 0-D vectors. +def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">; // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">; @@ -598,6 +603,10 @@ class VectorOf<list<Type> allowedTypes> : ShapedContainerType<allowedTypes, IsVectorTypePred, "vector", "::mlir::VectorType">; +// Temporary vector type clone that allows gradual transition to 0-D vectors. +class VectorOfAnyRankOf<list<Type> allowedTypes> : + ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector", + "::mlir::VectorType">; // Whether the number of elements of a vector is from the given // `allowedRanks` list @@ -649,6 +658,8 @@ "::mlir::VectorType">; def AnyVector : VectorOf<[AnyType]>; +// Temporary vector type clone that allows gradual transition to 0-D vectors. +def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; // Shaped types. diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -369,13 +369,17 @@ return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); } -/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type -/// when n > 1.
For example, `vector<4 x f32>` remains as is while, -/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`. +/// Convert an n-D vector type to an LLVM vector type: +/// * 0-D `vector<T>` are converted to vector<1xT> +/// * 1-D `vector<axT>` remains as is while, +/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to +/// `!llvm.array<ax...array<jxvector<kxT>>>`. Type LLVMTypeConverter::convertVectorType(VectorType type) { auto elementType = convertType(type.getElementType()); if (!elementType) return {}; + if (type.getShape().empty()) + return VectorType::get({1}, elementType); Type vectorType = VectorType::get(type.getShape().back(), elementType); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -40,6 +40,7 @@ LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos) { + assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -56,6 +57,7 @@ static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos) { + assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -542,6 +544,17 @@ if (!llvmType) return failure(); + if (vectorType.getRank() == 0) { + Location loc = extractEltOp.getLoc(); + auto idxType = rewriter.getIndexType(); + auto zero = rewriter.create<LLVM::ConstantOp>( + loc, typeConverter->convertType(idxType), + rewriter.getIntegerAttr(idxType, 0)); + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+ extractEltOp, llvmType, adaptor.vector(), zero); + return success(); + } + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( extractEltOp, llvmType, adaptor.vector(), adaptor.position()); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -832,6 +832,12 @@ // ExtractElementOp //===----------------------------------------------------------------------===// +void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, + Value source) { + result.addOperands({source}); + result.addTypes(source.getType().cast<VectorType>().getElementType()); +} + void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, Value source, Value position) { result.addOperands({source, position}); @@ -840,8 +846,15 @@ static LogicalResult verify(vector::ExtractElementOp op) { VectorType vectorType = op.getVectorType(); + if (vectorType.getRank() == 0) { + if (op.position()) + return op.emitOpError("expected position to be empty with 0-D vector"); + return success(); + } if (vectorType.getRank() != 1) - return op.emitOpError("expected 1-D vector"); + return op.emitOpError("unexpected >1 vector rank"); + if (!op.position()) + return op.emitOpError("expected position for 1-D vector"); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -418,6 +418,16 @@ // ----- +// CHECK-LABEL: @extract_element_0d +func @extract_element_0d(%a: vector<f32>) -> f32 { + // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32> + %1 = vector.extractelement %a[] : vector<f32> + return %1 : f32 +} + +// ----- + func @extract_element(%arg0: vector<16xf32>) -> f32 { %0 = arith.constant 15 :
i32 %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -72,9 +72,25 @@ // ----- +func @extract_element(%arg0: vector<f32>) { + %c = arith.constant 3 : i32 + // expected-error@+1 {{expected position to be empty with 0-D vector}} + %1 = vector.extractelement %arg0[%c : i32] : vector<f32> +} + +// ----- + +func @extract_element(%arg0: vector<4xf32>) { + %c = arith.constant 3 : i32 + // expected-error@+1 {{expected position for 1-D vector}} + %1 = vector.extractelement %arg0[] : vector<4xf32> +} + +// ----- + func @extract_element(%arg0: vector<4x4xf32>) { %c = arith.constant 3 : i32 - // expected-error@+1 {{'vector.extractelement' op expected 1-D vector}} + // expected-error@+1 {{unexpected >1 vector rank}} %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -163,6 +163,13 @@ return %1 : vector<3x4xf32> } +// CHECK-LABEL: @extract_element_0d +func @extract_element_0d(%a: vector<f32>) -> f32 { + // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> + %1 = vector.extractelement %a[] : vector<f32> + return %1 : f32 +} + // CHECK-LABEL: @extract_element func @extract_element(%a: vector<16xf32>) -> f32 { // CHECK: %[[C15:.*]] = arith.constant 15 : i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN:
-shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @extract_element_0d(%a: vector<f32>) { + %1 = vector.extractelement %a[] : vector<f32> + // CHECK: 42 + vector.print %1: f32 + return +} + +func @entry() { + %1 = arith.constant dense<42.0> : vector<f32> + call @extract_element_0d(%1) : (vector<f32>) -> () + return +}