Patch summary: allow `tensor_from_elements` to build a zero-element
tensor (the tests below use `tensor<0xindex>`).

* Ops.td: removes the `SameOperandsAndResultElementType` trait.
  The `TypesMatchWith<"operand types match result element type", ...>`
  constraint remains as a context line.
  It also sets `skipDefaultBuilders = 1` and declares a new builder that
  takes an explicit `Type elementType`.
  The existing `ValueRange`-only builder is kept and is now commented as
  the special case for `elements` of size >= 1.
  NOTE(review): presumably the trait was dropped because it has no
  operand to compare against when `elements` is empty; confirm.
* Ops.cpp: the new builder computes the result type from
  `elements.size()` and `elementType`, then adds the operands and the
  result type. Its `builder` parameter is unused.
  The `ValueRange`-only builder still asserts `!elements.empty()` and now
  forwards `elements.front().getType()` to the new builder.
* ShapeToStandard.cpp: both places that create the extent tensor now pass
  the index type explicitly (`indexTy` / `rewriter.getIndexType()`).
  NOTE(review): presumably this lets an empty extent list avoid the
  builder that asserts non-emptiness; confirm against the op builders.
* Tests: a rank-0 `shape.const_shape []` lowering, `shape.shape_of` on a
  0-D tensor, and a parse/print round-trip of an operand-less
  `tensor_from_elements : tensor<0xindex>` in core-ops.mlir.
  NOTE(review): `@shape_of_zero_d` only has a CHECK-DAG for the
  `tensor_from_elements` line. It checks neither the subsequent cast nor
  the result.

NOTE(review): this copy of the patch is damaged and should not be applied
as-is.
* Newlines have been collapsed onto a few lines.
* Text between `<` and `>` that resembles an HTML tag appears to have
  been stripped. For example, `rewriter.create(...)` presumably was
  `rewriter.create<SomeOp>(...)`, `static_cast(...)` presumably was
  `static_cast<int64_t>(...)`, and bare `tensor` types presumably carried
  a `<...>` shape/element suffix (e.g. `tensor<?xindex>`).

Regenerate the patch from its source commit rather than hand-repairing it.

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1613,7 +1613,6 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [ NoSideEffect, - SameOperandsAndResultElementType, TypesMatchWith<"operand types match result element type", "result", "elements", "SmallVector(" "$_self.cast().getDimSize(0), " @@ -1638,7 +1637,11 @@ // This op is fully verified by its traits. let verifier = ?; + let skipDefaultBuilders = 1; let builders = [ + OpBuilder<"OpBuilder &b, OperationState &result, Type elementType," + "ValueRange elements">, + // Special case builder for when `elements` has size >=1. OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements"> ]; diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -182,8 +182,9 @@ extentOperands.push_back( rewriter.create(loc, extent.getLimitedValue())); } - Value tensor = rewriter.create(loc, extentOperands); Type indexTy = rewriter.getIndexType(); + Value tensor = + rewriter.create(loc, indexTy, extentOperands); Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); rewriter.replaceOpWithNewOp(op, tensor, resultTy); return success(); @@ -444,8 +445,8 @@ } // Materialize extent tensor. 
- Value staticExtentTensor = - rewriter.create(loc, extentValues); + Value staticExtentTensor = rewriter.create( + loc, rewriter.getIndexType(), extentValues); rewriter.replaceOpWithNewOp(op, staticExtentTensor, op.getType()); return success(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1756,12 +1756,18 @@ // TensorFromElementsOp //===----------------------------------------------------------------------===// +void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, + Type elementType, ValueRange elements) { + Type resultTy = RankedTensorType::get({static_cast(elements.size())}, + elementType); + result.addOperands(elements); + result.addTypes(resultTy); +} + void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, ValueRange elements) { assert(!elements.empty() && "expected at least one element"); - Type resultTy = RankedTensorType::get({static_cast(elements.size())}, - elements.front().getType()); - build(builder, result, resultTy, elements); + build(builder, result, elements.front().getType(), elements); } namespace { diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -103,6 +103,19 @@ // ----- +// Lower `const_shape` in the case of rank 0. 
+// CHECK-LABEL: func @const_shape_zero_elements +// CHECK-SAME: () -> tensor +func @const_shape_zero_elements() -> tensor { + // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex> + // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor + // CHECK: return %[[RESULT]] : tensor + %shape = shape.const_shape [] : tensor + return %shape : tensor +} + +// ----- + // Lower `any` to its first operand. // CHECK-LABEL: @any_of_three // CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor, %[[C:.*]]: tensor) -> tensor @@ -227,6 +240,17 @@ // ----- +// Lower `shape_of` for 0-D tensor. +// CHECK-LABEL: @shape_of_zero_d +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @shape_of_zero_d(%arg : tensor) { + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex> + %shape = shape.shape_of %arg : tensor -> tensor + return +} + +// ----- + // Lower `shape_of` for dynamically shaped tensor. // CHECK-LABEL: @shape_of_dyn // CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>) diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -673,6 +673,9 @@ // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32> %2 = tensor_from_elements %c0_f32 : tensor<1xf32> + // CHECK: tensor_from_elements : tensor<0xindex> + %3 = tensor_from_elements : tensor<0xindex> + return }