Use the declarative assembly format for llvm.extractelement and
llvm.insertelement.

The hand-written print/parse/verify methods of both ops are removed from
LLVMDialect.cpp. Their textual syntax is now described by `assemblyFormat`
in LLVMOps.td, and the type relations previously checked in C++ are
expressed as ODS traits:
  * extractelement: TypesMatchWith ties the result type to the element
    type of `vector`.
  * insertelement: TypesMatchWith ties the type of `value` to the element
    type of `vector`, and AllTypesMatch requires `res` and `vector` to have
    the same type.
The expected diagnostics in test/Dialect/LLVMIR/invalid.mlir are updated
to the messages produced by the ODS-generated checks.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -701,17 +701,26 @@
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
+def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect,
+    TypesMatchWith<"result type matches vector element type", "vector", "res",
+                   "LLVM::getVectorElementType($_self)">]> {
+  let summary = "Extract an element from an LLVM vector.";
+
   let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
   let results = (outs LLVM_Type:$res);
+
+  let builders = [
+    OpBuilder<(ins "Value":$vector, "Value":$position,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+
+  let assemblyFormat = [{
+    $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+  }];
+
   string llvmBuilder = [{
     $res = builder.CreateExtractElement($vector, $position);
   }];
-  let builders = [
-    OpBuilder<(ins "Value":$vector, "Value":$position,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -746,16 +755,26 @@
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
-def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
+def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect,
+    TypesMatchWith<"argument type matches vector element type", "vector",
+                   "value", "LLVM::getVectorElementType($_self)">,
+    AllTypesMatch<["res", "vector"]>]> {
+  let summary = "Insert an element into an LLVM vector.";
+
   let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
                    AnyInteger:$position);
   let results = (outs LLVM_AnyVector:$res);
+
+  let builders = [LLVM_OneResultOpBuilder];
+
+  let assemblyFormat = [{
+    $value `,` $vector `[` $position `:` type($position) `]` attr-dict `:`
+    type($vector)
+  }];
+
   string llvmBuilder = [{
     $res = builder.CreateInsertElement($vector, $value, $position);
   }];
-  let builders = [LLVM_OneResultOpBuilder];
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1365,10 +1365,10 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::ExtractElementOp.
+// ExtractElementOp
 //===----------------------------------------------------------------------===//
-// Expects vector to be of wrapped LLVM vector type and position to be of
-// wrapped LLVM i32 type.
+
+/// Expects vector to be an LLVM vector type and position to be an integer type.
 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
                                    Value vector, Value position,
                                    ArrayRef<NamedAttribute> attrs) {
@@ -1378,49 +1378,6 @@
   result.addAttributes(attrs);
 }
 
-void ExtractElementOp::print(OpAsmPrinter &p) {
-  p << ' ' << getVector() << "[" << getPosition() << " : "
-    << getPosition().getType() << "]";
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getVector().getType();
-}
-
-// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
-//                 attribute-dict? `:` type
-ParseResult ExtractElementOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  SMLoc loc;
-  OpAsmParser::UnresolvedOperand vector, position;
-  Type type, positionType;
-  if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
-      parser.parseLSquare() || parser.parseOperand(position) ||
-      parser.parseColonType(positionType) || parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(vector, type, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands))
-    return failure();
-  if (!LLVM::isCompatibleVectorType(type))
-    return parser.emitError(
-        loc, "expected LLVM dialect-compatible vector type for operand #1");
-  result.addTypes(LLVM::getVectorElementType(type));
-  return success();
-}
-
-LogicalResult ExtractElementOp::verify() {
-  Type vectorType = getVector().getType();
-  if (!LLVM::isCompatibleVectorType(vectorType))
-    return emitOpError("expected LLVM dialect-compatible vector type for "
-                       "operand #1, got")
-           << vectorType;
-  Type valueType = LLVM::getVectorElementType(vectorType);
-  if (valueType != getRes().getType())
-    return emitOpError() << "Type mismatch: extracting from " << vectorType
-                         << " should produce " << valueType
-                         << " but this op returns " << getRes().getType();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // ExtractValueOp
 //===----------------------------------------------------------------------===//
@@ -1530,57 +1487,6 @@
         container, builder.getAttr<DenseI64ArrayAttr>(position));
 }
 
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::InsertElementOp.
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::print(OpAsmPrinter &p) {
-  p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
-    << getPosition().getType() << "]";
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getVector().getType();
-}
-
-// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
-//                 attribute-dict? `:` type
-ParseResult InsertElementOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  SMLoc loc;
-  OpAsmParser::UnresolvedOperand vector, value, position;
-  Type vectorType, positionType;
-  if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
-      parser.parseComma() || parser.parseOperand(vector) ||
-      parser.parseLSquare() || parser.parseOperand(position) ||
-      parser.parseColonType(positionType) || parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(vectorType))
-    return failure();
-
-  if (!LLVM::isCompatibleVectorType(vectorType))
-    return parser.emitError(
-        loc, "expected LLVM dialect-compatible vector type for operand #1");
-  Type valueType = LLVM::getVectorElementType(vectorType);
-  if (!valueType)
-    return failure();
-
-  if (parser.resolveOperand(vector, vectorType, result.operands) ||
-      parser.resolveOperand(value, valueType, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands))
-    return failure();
-
-  result.addTypes(vectorType);
-  return success();
-}
-
-LogicalResult InsertElementOp::verify() {
-  Type valueType = LLVM::getVectorElementType(getVector().getType());
-  if (valueType != getValue().getType())
-    return emitOpError() << "Type mismatch: cannot insert "
-                         << getValue().getType() << " into "
-                         << getVector().getType();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // InsertValueOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -458,14 +458,14 @@
 // -----
 
 func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error@+1 {{expected LLVM dialect-compatible vector type for operand #1}}
+  // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
   %0 = llvm.extractelement %arg2[%arg1 : i32] : f32
 }
 
 // -----
 
 func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error@+1 {{expected LLVM dialect-compatible vector type for operand #1}}
+  // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
   %0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
 }
 
@@ -479,7 +479,7 @@
 // -----
 
 func.func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
-  // expected-error@+1 {{'llvm.extractelement' op Type mismatch: extracting from 'vector<4xf32>' should produce 'f32' but this op returns 'vector<4xf32>'}}
+  // expected-error@+1 {{failed to verify that result type matches vector element type}}
   %b = "llvm.extractelement"(%a, %idx) : (vector<4xf32>, i32) -> vector<4xf32>
   return %b : vector<4xf32>
 }
@@ -487,7 +487,7 @@
 // -----
 
 func.func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
-  // expected-error@+1 {{'llvm.insertelement' op Type mismatch: cannot insert 'vector<4xf32>' into 'vector<4xf32>'}}
+  // expected-error@+1 {{failed to verify that argument type matches vector element type}}
   %b = "llvm.insertelement"(%a, %a, %idx) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
   return %b : vector<4xf32>
 }