diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1611,8 +1611,14 @@ // TensorFromElementsOp //===----------------------------------------------------------------------===// -def TensorFromElementsOp : Std_Op<"tensor_from_elements", - [NoSideEffect, SameOperandsAndResultElementType]> { +def TensorFromElementsOp : Std_Op<"tensor_from_elements", [ + NoSideEffect, + SameOperandsAndResultElementType, + TypesMatchWith<"operand types match result element type", + "result", "elements", "SmallVector(" + "$_self.cast().getDimSize(0), " + "$_self.cast().getElementType())"> + ]> { string summary = "tensor from elements operation."; string description = [{ Create a 1D tensor from a range of same-type arguments. @@ -1625,9 +1631,13 @@ }]; let arguments = (ins Variadic:$elements); - let results = (outs AnyTensor:$result); + let results = (outs 1DTensorOf<[AnyType]>:$result); + + let assemblyFormat = "$elements attr-dict `:` type($result)"; + + // This op is fully verified by its traits. + let verifier = ?; - let skipDefaultBuilders = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements"> ]; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1756,50 +1756,12 @@ // TensorFromElementsOp //===----------------------------------------------------------------------===// -static ParseResult parseTensorFromElementsOp(OpAsmParser &parser, - OperationState &result) { - SmallVector elementsOperands; - Type resultType; - if (parser.parseOperandList(elementsOperands) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType)) - return failure(); - - if (parser.resolveOperands(elementsOperands, - resultType.cast().getElementType(), - result.operands)) - return failure(); - - result.addTypes(resultType); - return success(); -} - -static void print(OpAsmPrinter &p, TensorFromElementsOp op) { - p << "tensor_from_elements " << op.elements(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getType(); -} - -static LogicalResult verify(TensorFromElementsOp op) { - auto resultTensorType = op.result().getType().dyn_cast(); - if (!resultTensorType) - return op.emitOpError("expected result type to be a ranked tensor"); - - int64_t elementsCount = static_cast(op.elements().size()); - if (resultTensorType.getRank() != 1 || - resultTensorType.getShape().front() != elementsCount) - return op.emitOpError() - << "expected result type to be a 1D tensor with " << elementsCount - << (elementsCount == 1 ? " element" : " elements"); - return success(); -} - void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, ValueRange elements) { assert(!elements.empty() && "expected at least one element"); - result.addOperands(elements); - result.addTypes(RankedTensorType::get({static_cast(elements.size())}, - *elements.getTypes().begin())); + Type resultTy = RankedTensorType::get({static_cast(elements.size())}, + elements.front().getType()); + build(builder, result, resultTy, elements); } namespace { diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -595,7 +595,7 @@ // ----- func @tensor_from_elements_wrong_result_type() { - // expected-error@+2 {{expected result type to be a ranked tensor}} + // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} %c0 = constant 0 : i32 %0 = tensor_from_elements %c0 : tensor<*xi32> return @@ -604,7 +604,7 @@ // ----- func @tensor_from_elements_wrong_elements_count() { - // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}} + // expected-error@+2 {{1 operands present, but expected 2}} %c0 = constant 0 : index %0 = tensor_from_elements %c0 : tensor<2xindex> return