diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1589,14 +1589,9 @@ let results = (outs AnyTensor:$result); let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, ValueRange elements", [{ - assert(!elements.empty() && "expected at least one element"); - result.addOperands(elements); - result.addTypes( - RankedTensorType::get({static_cast<int64_t>(elements.size())}, - *elements.getTypes().begin())); - }]>]; + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements"> + ]; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1684,9 +1684,9 @@ OperationState &result) { SmallVector<OpAsmParser::OperandType, 4> elementsOperands; Type resultType; - if (parser.parseLParen() || parser.parseOperandList(elementsOperands) || - parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(resultType)) + if (parser.parseOperandList(elementsOperands) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultType)) return failure(); if (parser.resolveOperands(elementsOperands, @@ -1699,9 +1699,9 @@ } static void print(OpAsmPrinter &p, TensorFromElementsOp op) { - p << "tensor_from_elements(" << op.elements() << ')'; + p << "tensor_from_elements " << op.elements(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.result().getType(); + p << " : " << op.getType(); } static LogicalResult verify(TensorFromElementsOp op) { @@ -1718,6 +1718,14 @@ return success(); } +void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, + ValueRange elements) { + assert(!elements.empty() && "expected at least one 
element"); + result.addOperands(elements); + result.addTypes(RankedTensorType::get({static_cast<int64_t>(elements.size())}, + *elements.getTypes().begin())); +} + namespace { // Canonicalizes the pattern of the form diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -44,7 +44,7 @@ // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex> return } @@ -59,7 +59,7 @@ // CHECK-DAG: %[[C5:.*]] = constant 5 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex> return } @@ -134,7 +134,7 @@ // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C3:.*]] = constant 3 : index - // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) + // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex> // CHECK: return %[[RESULT]] : tensor<?xindex> %shape = shape.const_shape [1, 2, 3] : tensor<?xindex> diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ 
b/mlir/test/IR/core-ops.mlir @@ -661,17 
+661,17 @@ // CHECK-LABEL: func @tensor_from_elements() { func @tensor_from_elements() { %c0 = "std.constant"() {value = 0: index} : () -> index - // CHECK: %0 = tensor_from_elements(%c0) : tensor<1xindex> - %0 = tensor_from_elements(%c0) : tensor<1xindex> + // CHECK: %0 = tensor_from_elements %c0 : tensor<1xindex> + %0 = tensor_from_elements %c0 : tensor<1xindex> %c1 = "std.constant"() {value = 1: index} : () -> index - // CHECK: %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex> - %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex> + // CHECK: %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> + %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32 // CHECK: [[C0_F32:%.*]] = constant - // CHECK: %2 = tensor_from_elements([[C0_F32]]) : tensor<1xf32> - %2 = tensor_from_elements(%c0_f32) : tensor<1xf32> + // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32> + %2 = tensor_from_elements %c0_f32 : tensor<1xf32> return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -597,7 +597,7 @@ func @tensor_from_elements_wrong_result_type() { // expected-error@+2 {{expected result type to be a ranked tensor}} %c0 = constant 0 : i32 - %0 = tensor_from_elements(%c0) : tensor<*xi32> + %0 = tensor_from_elements %c0 : tensor<*xi32> return } @@ -606,7 +606,7 @@ func @tensor_from_elements_wrong_elements_count() { // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}} %c0 = constant 0 : index - %0 = tensor_from_elements(%c0) : tensor<2xindex> + %0 = tensor_from_elements %c0 : tensor<2xindex> return } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -981,7 +981,7 @@ func @extract_element_from_tensor_from_elements(%element : index) -> index { // CHECK-SAME: 
([[ARG:%.*]]: index) %c0 = constant 0 : index - %tensor = tensor_from_elements(%element) : tensor<1xindex> + %tensor = tensor_from_elements %element : tensor<1xindex> %extracted_element = extract_element %tensor[%c0] : tensor<1xindex> // CHECK: [[ARG]] : index return %extracted_element : index