diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -544,6 +544,7 @@ $res = builder.CreateInsertElement($vector, $value, $position); }]; let builders = [LLVM_OneResultOpBuilder]; + let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertElementOp(parser, result); }]; let printer = [{ printInsertElementOp(p, *this); }]; } @@ -560,6 +561,7 @@ [{ build($_builder, $_state, container.getType(), container, value, position); }]>]; + let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertValueOp(parser, result); }]; let printer = [{ printInsertValueOp(p, *this); }]; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -541,6 +541,7 @@ } // For all other cases, insert the individual values individually. + Type elt_type = llvmType.cast().getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { @@ -551,7 +552,7 @@ value = adaptor.v2(); } Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, - llvmType, rank, extPos); + elt_type, rank, extPos); insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, llvmType, rank, insPos++); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1024,6 +1024,52 @@ return llvmType; } +// Extract the type at `position` in the wrapped LLVM IR aggregate type +// `containerType`. Returns null on failure. 
+// On failure, a diagnostic describing the problem is emitted on `op`.
+static Type getInsertExtractValueElementType(Type containerType,
+                                             ArrayAttr positionAttr,
+                                             Operation *op) {
+  Type llvmType = containerType;
+  if (!isCompatibleType(containerType)) {
+    op->emitError("expected LLVM IR Dialect type, got ") << containerType;
+    return {};
+  }
+
+  // Walk into the aggregate one index at a time: arrays yield their element
+  // type, structs the body type at that index. Every index is bounds-checked
+  // (as a 64-bit value, so oversized literals cannot wrap into range) before
+  // it is used; any other type at this point cannot be indexed.
+  for (Attribute subAttr : positionAttr) {
+    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+    if (!positionElementAttr) {
+      op->emitOpError("expected an array of integer literals, got: ")
+          << subAttr;
+      return {};
+    }
+    int64_t position = positionElementAttr.getInt();
+    if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= arrayType.getNumElements()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = arrayType.getElementType();
+    } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= structType.getBody().size()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = structType.getBody()[position];
+    } else {
+      op->emitOpError("expected LLVM IR structure/array type, got: ")
+          << llvmType;
+      return {};
+    }
+  }
+  return llvmType;
+}
+
 // ::= `llvm.extractvalue` ssa-use
 //     `[` integer-literal (`,` integer-literal)* `]`
 //     attribute-dict?
`:` type @@ -1104,6 +1150,14 @@ return success(); } +static LogicalResult verify(InsertElementOp op) { + Type valueType = LLVM::getVectorElementType(op.vector().getType()); + if (valueType != op.value().getType()) + return op.emitOpError() + << "Type mismatch: cannot insert " << op.value().getType() + << " into " << op.vector().getType(); + return success(); +} //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::InsertValueOp. //===----------------------------------------------------------------------===// @@ -1147,6 +1201,20 @@ return success(); } +static LogicalResult verify(InsertValueOp op) { + Type valueType = getInsertExtractValueElementType(op.container().getType(), + op.positionAttr(), op); + if (!valueType) + return failure(); + + if (op.value().getType() != valueType) + return op.emitOpError() + << "Type mismatch: cannot insert " << op.value().getType() + << " into " << op.container().getType(); + + return success(); +} + //===----------------------------------------------------------------------===// // Printing, parsing and verification for LLVM::ReturnOp. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s +// Verifier failure on insertvalue +// XFAIL: * func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> { %0 = vector.bitcast %input : vector<16xf32> to vector<16xi32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle.mlir @@ -3,6 +3,7 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s + func @entry() { %f1 = constant 1.0: f32 %f2 = constant 2.0: f32