diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -706,20 +706,32 @@
                    "LLVM::getVectorElementType($_self)">]> {
   let summary = "Extract an element from an LLVM vector.";
 
-  let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
+  let arguments = (ins LLVM_AnyVector:$vector,
+                       Optional<AnyInteger>:$dynamicPosition,
+                       OptionalAttr<IndexAttr>:$constPosition);
   let results = (outs LLVM_Type:$res);
 
   let builders = [
     OpBuilder<(ins "Value":$vector, "Value":$position,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Type":$res, "Value":$vector, "Value":$position)>,
+    OpBuilder<(ins "Value":$vector, "int64_t":$position)>,
   ];
 
   let assemblyFormat = [{
-    $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+    $vector `[`
+      ($dynamicPosition^ `:` type($dynamicPosition)) : (`` $constPosition) ?
+    `]` attr-dict `:` type($vector)
   }];
 
+  let hasVerifier = 1;
+
   string llvmBuilder = [{
-    $res = builder.CreateExtractElement($vector, $position);
+    if (op.getDynamicPosition())
+      $res = builder.CreateExtractElement($vector, $dynamicPosition.front());
+    else
+      $res = builder.CreateExtractElement(
+          $vector, builder.getInt(op.getConstPosition().value()));
   }];
 }
 
@@ -762,18 +774,33 @@
   let summary = "Insert an element into an LLVM vector.";
 
   let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
-                       AnyInteger:$position);
+                       Optional<AnyInteger>:$dynamicPosition,
+                       OptionalAttr<IndexAttr>:$constPosition);
   let results = (outs LLVM_AnyVector:$res);
 
-  let builders = [LLVM_OneResultOpBuilder];
+  let builders = [
+    LLVM_OneResultOpBuilder,
+    OpBuilder<(ins "Type":$res, "Value":$vector, "Value":$value,
+      "Value":$position)>,
+    OpBuilder<(ins "Value":$vector, "Value":$value, "Value":$position)>,
+    OpBuilder<(ins "Value":$vector, "Value":$value, "int64_t":$position)>,
+  ];
 
   let assemblyFormat = [{
-    $value `,` $vector `[` $position `:` type($position) `]` attr-dict `:`
-    type($vector)
+    $value `,` $vector `[`
+      ($dynamicPosition^ `:` type($dynamicPosition)) : (`` $constPosition) ?
+    `]` attr-dict `:` type($vector)
   }];
 
+  let hasVerifier = 1;
+
   string llvmBuilder = [{
-    $res = builder.CreateInsertElement($vector, $value, $position);
+    if (op.getDynamicPosition())
+      $res = builder.CreateInsertElement($vector, $value,
+                                         $dynamicPosition.front());
+    else
+      $res = builder.CreateInsertElement(
+          $vector, $value, builder.getInt(op.getConstPosition().value()));
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1368,14 +1368,58 @@
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-/// Expects vector to be an LLVM vector type and position to be an integer type.
-void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
-                                   Value vector, Value position,
-                                   ArrayRef<NamedAttribute> attrs) {
-  auto vectorType = vector.getType();
-  auto llvmType = LLVM::getVectorElementType(vectorType);
-  build(b, result, llvmType, vector, position);
-  result.addAttributes(attrs);
+/// Verify the dynamic and constant indices of an insert or extract element
+/// operation: exactly one of them must be present, and a constant index must
+/// be within bounds whenever the vector length is statically known.
+static LogicalResult verifyElementIndex(Operation *op, Value dynamicPosition,
+                                        IntegerAttr constPosition,
+                                        Value vector) {
+  if (dynamicPosition && constPosition)
+    return op->emitOpError(
+        "expected either a dynamic position or constant position, not both");
+  if (!dynamicPosition && !constPosition)
+    return op->emitOpError(
+        "expected a dynamic position or a constant position");
+
+  if (constPosition) {
+    int64_t pos = constPosition.getInt();
+    // LLVM_AnyVector also admits LLVM dialect vector types, so query the
+    // length through the LLVM helper rather than casting to VectorType.
+    // Scalable vectors only have a known minimum length; for them, only a
+    // negative index can be rejected statically.
+    llvm::ElementCount numElements =
+        LLVM::getVectorNumElements(vector.getType());
+    if (pos < 0 ||
+        (!numElements.isScalable() && pos >= numElements.getFixedValue()))
+      return op->emitOpError("constant index out-of-range");
+  }
+  return success();
+}
+
+LogicalResult ExtractElementOp::verify() {
+  return verifyElementIndex(*this, getDynamicPosition(), getConstPositionAttr(),
+                            getVector());
+}
+
+void ExtractElementOp::build(OpBuilder &builder, OperationState &state,
+                             Value vector, Value position,
+                             ArrayRef<NamedAttribute> attrs) {
+  build(builder, state, LLVM::getVectorElementType(vector.getType()), vector,
+        position);
+  state.addAttributes(attrs);
+}
+
+void ExtractElementOp::build(OpBuilder &builder, OperationState &state,
+                             Type res, Value vector, Value position) {
+  build(builder, state, res, vector, /*dynamicPosition=*/position,
+        /*constPosition=*/nullptr);
+}
+
+void ExtractElementOp::build(OpBuilder &builder, OperationState &state,
+                             Value vector, int64_t position) {
+  build(builder, state, LLVM::getVectorElementType(vector.getType()), vector,
+        /*dynamicPosition=*/nullptr,
+        /*constPosition=*/builder.getIndexAttr(position));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1487,6 +1531,33 @@
         container, builder.getAttr<DenseI64ArrayAttr>(position));
 }
 
+//===----------------------------------------------------------------------===//
+// InsertElementOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult InsertElementOp::verify() {
+  return verifyElementIndex(*this, getDynamicPosition(), getConstPositionAttr(),
+                            getVector());
+}
+
+void InsertElementOp::build(OpBuilder &builder, OperationState &state, Type res,
+                            Value vector, Value value, Value position) {
+  build(builder, state, res, vector, value, /*dynamicPosition=*/position,
+        /*constPosition=*/nullptr);
+}
+
+void InsertElementOp::build(OpBuilder &builder, OperationState &state,
+                            Value vector, Value value, Value position) {
+  build(builder, state, vector.getType(), vector, value, position);
+}
+
+void InsertElementOp::build(OpBuilder &builder, OperationState &state,
+                            Value vector, Value value, int64_t position) {
+  build(builder, state, vector.getType(), vector, value,
+        /*dynamicPosition=*/nullptr,
+        /*constPosition=*/builder.getIndexAttr(position));
+}
+
 //===----------------------------------------------------------------------===//
 // InsertValueOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -494,6 +494,30 @@
 
 // -----
 
+func.func @invalid_element_index(%a : vector<4xf32>, %idx : i32) -> f32 {
+  // expected-error@+1 {{expected either a dynamic position or constant position, not both}}
+  %b = "llvm.extractelement"(%a, %idx) {constPosition = 1 : index} : (vector<4xf32>, i32) -> f32
+  return %b : f32
+}
+
+// -----
+
+func.func @invalid_element_index(%a : vector<4xf32>, %idx : i32) -> f32 {
+  // expected-error@+1 {{expected a dynamic position or a constant position}}
+  %b = "llvm.extractelement"(%a) : (vector<4xf32>) -> f32
+  return %b : f32
+}
+
+// -----
+
+func.func @invalid_element_index(%a : vector<4xf32>, %idx : i32) -> f32 {
+  // expected-error@+1 {{constant index out-of-range}}
+  %b = llvm.extractelement %a[5] : vector<4xf32>
+  return %b : f32
+}
+
+// -----
+
 func.func @null_non_llvm_type() {
   // expected-error@+1 {{custom op 'llvm.mlir.null' invalid kind of type specified}}
   llvm.mlir.null : i32
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -522,3 +522,26 @@
   // CHECK: llvm.return
   llvm.return
 }
+
+// CHECK-LABEL: @insert_extract_element
+// CHECK-SAME: %[[V0:.*]]: vector<4xf32>, %[[V1:.*]]: vector<4xf32>, %[[I:.*]]: i32
+llvm.func @insert_extract_element(%v0: vector<4xf32>, %v1: vector<4xf32>, %i: i32) {
+  // CHECK: %[[E0:.*]] = llvm.extractelement %[[V0]][%[[I]] : i32]
+  // CHECK: %[[E1:.*]] = llvm.extractelement %[[V0]][2]
+  // CHECK: %[[E2:.*]] = llvm.insertelement %[[E0]], %[[V1]][%[[I]] : i32]
+  // CHECK: %[[E3:.*]] = llvm.insertelement %[[E1]], %[[E2]][3]
+  %0 = llvm.extractelement %v0[%i : i32] : vector<4xf32>
+  %1 = llvm.extractelement %v0[2] : vector<4xf32>
+  %2 = llvm.insertelement %0, %v1[%i : i32] : vector<4xf32>
+  %3 = llvm.insertelement %1, %2[3] : vector<4xf32>
+  llvm.return
+}
+
+// CHECK-LABEL: @extract_element_scalable
+llvm.func @extract_element_scalable(%v: vector<[4]xf32>) {
+  // A constant index past the known minimum length is valid for scalable
+  // vectors.
+  // CHECK: llvm.extractelement %{{.*}}[7] : vector<[4]xf32>
+  %0 = llvm.extractelement %v[7] : vector<[4]xf32>
+  llvm.return
+}