diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -522,6 +522,7 @@
   let builders = [
     OpBuilder<(ins "Value":$vector, "Value":$position,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseExtractElementOp(parser, result); }];
   let printer = [{ printExtractElementOp(p, *this); }];
 }
@@ -532,6 +533,7 @@
     $res = builder.CreateExtractValue($container, extractPosition($position));
   }];
   let builders = [LLVM_OneResultOpBuilder];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseExtractValueOp(parser, result); }];
   let printer = [{ printExtractValueOp(p, *this); }];
   let hasFolder = 1;
@@ -544,6 +546,7 @@
     $res = builder.CreateInsertElement($vector, $value, $position);
   }];
   let builders = [LLVM_OneResultOpBuilder];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseInsertElementOp(parser, result); }];
   let printer = [{ printInsertElementOp(p, *this); }];
 }
@@ -560,6 +563,7 @@
     [{
       build($_builder, $_state, container.getType(), container, value, position);
     }]>];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseInsertValueOp(parser, result); }];
   let printer = [{ printInsertValueOp(p, *this); }];
 }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -541,6 +541,13 @@
     }
 
     // For all other cases, insert the individual values individually.
+    // Each value pulled out of v1/v2 is one element of `llvmType` (an array
+    // element or a vector element), so extractOne is given that element type.
+    Type eltType;
+    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
+      eltType = arrayType.getElementType();
+    else
+      eltType = LLVM::getVectorElementType(llvmType);
     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
     int64_t insPos = 0;
     for (auto en : llvm::enumerate(maskArrayAttr)) {
@@ -551,7 +558,7 @@
         value = adaptor.v2();
       }
       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
-                                 llvmType, rank, extPos);
+                                 eltType, rank, extPos);
       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                          llvmType, rank, insPos++);
     }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -970,6 +970,20 @@
   return success();
 }
 
+static LogicalResult verify(ExtractElementOp op) {
+  Type vectorType = op.vector().getType();
+  if (!LLVM::isCompatibleVectorType(vectorType))
+    return op->emitOpError("expected LLVM dialect-compatible vector type for "
+                           "operand #1, got ")
+           << vectorType;
+  Type valueType = LLVM::getVectorElementType(vectorType);
+  if (valueType != op.res().getType())
+    return op.emitOpError() << "Type mismatch: extracting from " << vectorType
+                            << " should produce " << valueType
+                            << " but this op returns " << op.res().getType();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ExtractValueOp.
 //===----------------------------------------------------------------------===//
@@ -1024,6 +1038,52 @@
   return llvmType;
 }
 
+// Walks `containerType` along the indices in `positionAttr` and returns the
+// reached type; on failure, emits an error on `op` and returns a null Type.
+static Type getInsertExtractValueElementType(Type containerType,
+                                             ArrayAttr positionAttr,
+                                             Operation *op) {
+  Type llvmType = containerType;
+  if (!isCompatibleType(containerType)) {
+    op->emitOpError("expected LLVM IR Dialect type, got ") << containerType;
+    return {};
+  }
+
+  // Step into the aggregate once per entry of `positionAttr`: arrays yield
+  // their element type, structs the body element at that index. Each index is
+  // bounds-checked before use; it is read as int64_t so that large 64-bit
+  // values cannot wrap into the valid range.
+  for (Attribute subAttr : positionAttr) {
+    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+    if (!positionElementAttr) {
+      op->emitOpError("expected an array of integer literals, got: ")
+          << subAttr;
+      return {};
+    }
+    int64_t position = positionElementAttr.getInt();
+    if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= arrayType.getNumElements()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = arrayType.getElementType();
+    } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= structType.getBody().size()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = structType.getBody()[position];
+    } else {
+      op->emitOpError("expected LLVM IR structure/array type, got: ")
+          << llvmType;
+      return {};
+    }
+  }
+  return llvmType;
+}
+
 // <operation> ::= `llvm.extractvalue` ssa-use
 //                 `[` integer-literal (`,` integer-literal)* `]`
 //                 attribute-dict? `:` type
@@ -1062,6 +1122,21 @@
   return {};
 }
 
+// Verifies the result matches the type at `position` in the container.
+static LogicalResult verify(ExtractValueOp op) {
+  Type valueType = getInsertExtractValueElementType(op.container().getType(),
+                                                    op.positionAttr(), op);
+  if (!valueType)
+    return failure();
+
+  if (op.res().getType() != valueType)
+    return op.emitOpError()
+           << "Type mismatch: extracting from " << op.container().getType()
+           << " should produce " << valueType << " but this op returns "
+           << op.res().getType();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::InsertElementOp.
 //===----------------------------------------------------------------------===//
@@ -1104,6 +1179,21 @@
   return success();
 }
 
+// Verifies `vector` is an LLVM-compatible vector of the inserted value's type.
+static LogicalResult verify(InsertElementOp op) {
+  Type vectorType = op.vector().getType();
+  if (!LLVM::isCompatibleVectorType(vectorType))
+    return op->emitOpError("expected LLVM dialect-compatible vector type for "
+                           "operand #1, got ")
+           << vectorType;
+  Type valueType = LLVM::getVectorElementType(vectorType);
+  if (valueType != op.value().getType())
+    return op.emitOpError()
+           << "Type mismatch: cannot insert " << op.value().getType()
+           << " into " << vectorType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::InsertValueOp.
 //===----------------------------------------------------------------------===//
@@ -1147,6 +1237,21 @@
   return success();
 }
 
+// Verifies the inserted value matches the type at `position` in the container.
+static LogicalResult verify(InsertValueOp op) {
+  Type valueType = getInsertExtractValueElementType(op.container().getType(),
+                                                    op.positionAttr(), op);
+  if (!valueType)
+    return failure();
+
+  if (op.value().getType() != valueType)
+    return op.emitOpError()
+           << "Type mismatch: cannot insert " << op.value().getType()
+           << " into " << op.container().getType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and verification for LLVM::ReturnOp.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -355,6 +355,24 @@ llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)> } +// ----- + +func @insertvalue_invalid_type(%a : !llvm.array<1 x i32>) -> !llvm.array<1 x i32> { + // expected-error@+1 {{'llvm.insertvalue' op Type mismatch: cannot insert '!llvm.array<1 x i32>' into '!llvm.array<1 x i32>'}} + %b = "llvm.insertvalue"(%a, %a) {position = [0]} : (!llvm.array<1 x i32>, !llvm.array<1 x i32>) -> !llvm.array<1 x i32> + return %b : !llvm.array<1 x i32> +} + +// ----- + +func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>> { + // expected-error@+1 {{'llvm.extractvalue' op Type mismatch: extracting from '!llvm.array<4 x vector<8xf32>>' should produce 'vector<8xf32>' but this op returns '!llvm.array<4 x vector<8xf32>>'}} + %b = "llvm.extractvalue"(%a) {position = [1]} + : (!llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>> + return %b : !llvm.array<4 x vector<8xf32>> +} + + // ----- func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) { @@ -422,6 +440,22 @@ // ----- +func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> { + // expected-error@+1 {{'llvm.extractelement' op Type mismatch: extracting from 'vector<4xf32>' should produce 'f32' but this op returns 'vector<4xf32>'}} + %b = "llvm.extractelement"(%a, %idx) : (vector<4xf32>, i32) -> vector<4xf32> + return %b : vector<4xf32> +} + +// ----- + +func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> { + // expected-error@+1 {{'llvm.insertelement' op Type mismatch: cannot insert 'vector<4xf32>' into 'vector<4xf32>'}} + %b = "llvm.insertelement"(%a, %a, %idx) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> + return %b : vector<4xf32> +} + 
+// ----- + func @null_non_llvm_type() { // expected-error@+1 {{must be LLVM pointer type, but got 'i32'}} llvm.mlir.null : i32