diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1524,6 +1524,39 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+def TensorFromElementsOp : Std_Op<"tensor_from_elements",
+    [NoSideEffect, SameOperandsAndResultElementType]> {
+  string summary = "tensor from elements operation.";
+  string description = [{
+    Create a 1D tensor from a range of same-type arguments.
+
+    Example:
+
+    ```mlir
+    tensor_from_elements(i_1, ..., i_N) : tensor<Nxindex>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$elements);
+  let results = (outs AnyTensor:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, ValueRange elements", [{
+      assert(!elements.empty() && "expected at least one element");
+      result.addOperands(elements);
+      result.addTypes(
+          RankedTensorType::get({static_cast<int64_t>(elements.size())},
+                                *elements.getTypes().begin()));
+    }]>];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // FPExtOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1640,6 +1640,99 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
+                                             OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
+  Type resultType;
+  if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
+      parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColon())
+    return failure();
+
+  // Operands are resolved against the result's element type, so a non-tensor
+  // result type must be diagnosed here instead of failing a cast.
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+  auto tensorType = resultType.dyn_cast<TensorType>();
+  if (!tensorType)
+    return parser.emitError(typeLoc, "expected tensor type, but got ")
+           << resultType;
+
+  if (parser.resolveOperands(elementsOperands, tensorType.getElementType(),
+                             result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
+  p << "tensor_from_elements(" << op.elements() << ')';
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.result().getType();
+}
+
+static LogicalResult verify(TensorFromElementsOp op) {
+  auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
+  if (!resultTensorType)
+    return op.emitOpError("expected result type to be a ranked tensor");
+
+  int64_t elementsCount = static_cast<int64_t>(op.elements().size());
+  if (resultTensorType.getRank() != 1 ||
+      resultTensorType.getShape().front() != elementsCount)
+    return op.emitOpError()
+           << "expected result type to be a 1D tensor with " << elementsCount
+           << (elementsCount == 1 ? " element" : " elements");
+  return success();
+}
+
+namespace {
+
+// Canonicalizes the pattern of the form
+//
+// %tensor = tensor_from_elements(%element) : tensor<1xi32>
+// %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
+//
+// to just %element.
+struct ExtractElementFromTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    if (extract.indices().size() != 1)
+      return failure();
+
+    // The aggregate may be a block argument, which has no defining op.
+    auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
+        extract.aggregate().getDefiningOp());
+    if (!tensorFromElements)
+      return failure();
+
+    // Only constant, in-bounds indices can be folded; the unsigned compare
+    // also rejects negative constants.
+    APInt index;
+    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)) ||
+        index.uge(tensorFromElements.getNumOperands()))
+      return failure();
+    rewriter.replaceOp(extract,
+                       tensorFromElements.getOperand(index.getZExtValue()));
+    return success();
+  }
+};
+
+} // namespace
+
+void TensorFromElementsOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExtractElementFromTensorFromElements>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FPExtOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -644,6 +644,24 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: func @tensor_from_elements() {
+func @tensor_from_elements() {
+  %c0 = "std.constant"() {value = 0: index} : () -> index
+  // CHECK: %0 = tensor_from_elements(%c0) : tensor<1xindex>
+  %0 = tensor_from_elements(%c0) : tensor<1xindex>
+
+  %c1 = "std.constant"() {value = 1: index} : () -> index
+  // CHECK: %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+  %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+
+  %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32
+  // CHECK: [[C0_F32:%.*]] = constant
+  // CHECK: %2 = tensor_from_elements([[C0_F32]]) : tensor<1xf32>
+  %2 = tensor_from_elements(%c0_f32) : tensor<1xf32>
+
+  return
+}
+
 // CHECK-LABEL: func @tensor_cast(%arg0
 func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
   // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32>
to tensor diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -605,7 +605,24 @@ func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) { // expected-error@+1 {{incorrect number of indices for extract_element}} - %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 + %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 return +} + +// ----- + +func @tensor_from_elements_wrong_result_type() { + // expected-error@+2 {{expected result type to be a ranked tensor}} + %c0 = constant 0 : i32 + %0 = tensor_from_elements(%c0) : tensor<*xi32> + return +} + +// ----- + +func @tensor_from_elements_wrong_elements_count() { + // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}} + %c0 = constant 0 : index + %0 = tensor_from_elements(%c0) : tensor<2xindex> return } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -971,3 +971,26 @@ // CHECK: memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]> return %1: memref<3x4xf32, offset:?, strides:[?, 1]> } + +// ----- + +// CHECK-LABEL: func @extract_element_from_tensor_from_elements +func @extract_element_from_tensor_from_elements(%element : index) -> index { + // CHECK-SAME: ([[ARG:%.*]]: index) + %c0 = constant 0 : index + %tensor = tensor_from_elements(%element) : tensor<1xindex> + %extracted_element = extract_element %tensor[%c0] : tensor<1xindex> + // CHECK: return [[ARG]] : index + return %extracted_element : index +} + +// ----- + +// CHECK-LABEL: func @extract_element_dynamic_index_from_tensor_from_elements +func @extract_element_dynamic_index_from_tensor_from_elements(%element : index, %i : index) -> index { + // CHECK: [[TENSOR:%.*]] = tensor_from_elements + %tensor = tensor_from_elements(%element) : tensor<1xindex> + // CHECK: extract_element [[TENSOR]] + %extracted_element = extract_element %tensor[%i] : tensor<1xindex> + return %extracted_element : index +}