diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -182,6 +182,67 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ReshapeOp : Tensor_Op<"reshape", [NoSideEffect]> {
+  let summary = "tensor reshape operation";
+  let description = [{
+    The `reshape` operation converts a tensor from one type to an equivalent
+    type with a provided shape. The source and destination types are
+    compatible if both have the same element type and number of elements.
+    The following combinations are possible:
+
+    a. Source type is ranked or unranked. Shape argument has static size.
+    Result type is ranked.
+
+    ```mlir
+    // Reshape statically-shaped tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+    %dst0 = tensor.reshape %src(%shape0)
+             : (tensor<4x1xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+    // Flatten unranked tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+    ```
+
+    b. Source type is ranked or unranked. Shape argument has dynamic size.
+    Result type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
+    // Reshape unranked tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyTensor:$source,
+    TensorRankOf<[AnySignlessInteger, Index], [1]>:$shape
+  );
+  let results = (outs AnyTensor:$result);
+
+  let builders = [OpBuilder<
+    (ins "TensorType":$resultType, "Value":$operand, "Value":$shape), [{
+      $_state.addOperands(operand);
+      $_state.addOperands(shape);
+      $_state.addTypes(resultType);
+    }]>];
+
+  let extraClassDeclaration = [{
+    TensorType getResultType() { return getResult().getType().cast<TensorType>(); }
+  }];
+
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -442,6 +442,47 @@
               StaticTensorGenerate>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+// Returns the product of the dimension sizes; only meaningful for types with
+// a fully static shape.
+static int64_t getNumElements(ShapedType type) {
+  int64_t numElements = 1;
+  for (auto dim : type.getShape())
+    numElements *= dim;
+  return numElements;
+}
+
+static LogicalResult verify(ReshapeOp op) {
+  TensorType operandType = op.source().getType().cast<TensorType>();
+  TensorType resultType = op.result().getType().cast<TensorType>();
+
+  if (operandType.getElementType() != resultType.getElementType())
+    return op.emitOpError("element types of source and destination tensor "
+                          "types should be the same");
+
+  int64_t shapeSize =
+      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
+  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
+
+  if (resultRankedType) {
+    if (operandRankedType && resultRankedType.hasStaticShape() &&
+        operandRankedType.hasStaticShape()) {
+      if (getNumElements(operandRankedType) !=
+          getNumElements(resultRankedType))
+        return op.emitOpError("source and destination tensor should have the "
+                              "same number of elements");
+    }
+    if (shapeSize == TensorType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "reshape to statically-ranked tensor type");
+    if (shapeSize != resultRankedType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's tensor rank");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -97,3 +97,36 @@
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+// -----
+
+func @tensor.reshape_element_type_mismatch(
+       %buf: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{element types of source and destination tensor types should be the same}}
+  tensor.reshape %buf(%shape) : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xi32>
+}
+
+// -----
+
+func @tensor.reshape_dst_ranked_shape_unranked(
+       %buf: tensor<*xf32>, %shape: tensor<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked tensor type}}
+  tensor.reshape %buf(%shape) : (tensor<*xf32>, tensor<?xi32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @tensor.reshape_dst_shape_rank_mismatch(
+       %buf: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's tensor rank}}
+  tensor.reshape %buf(%shape)
+    : (tensor<*xf32>, tensor<1xi32>) -> tensor<?x?xf32>
+}
+
+// -----
+
+func @tensor.reshape_num_elements_mismatch(
+       %buf: tensor<1xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{source and destination tensor should have the same number of elements}}
+  tensor.reshape %buf(%shape)
+    : (tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -53,3 +53,15 @@
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+
+// CHECK-LABEL: func @tensor_reshape
+func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
+    %shape2: tensor<2xi32>, %shape3: tensor<?xi32>) -> tensor<*xf32> {
+  %dyn_vec = tensor.reshape %unranked(%shape1)
+               : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  %dyn_mat = tensor.reshape %dyn_vec(%shape2)
+               : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %new_unranked = tensor.reshape %dyn_mat(%shape3)
+               : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
+  return %new_unranked : tensor<*xf32>
+}
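
As a usage sketch (not part of the patch itself): the custom builder declared in `TensorOps.td` above can be driven from C++ roughly as follows. The helper `buildFlattenToDynamic1D` and its surrounding setup are hypothetical; only `tensor::ReshapeOp`, its builder signature `(TensorType resultType, Value operand, Value shape)`, and the verifier rules come from the change.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: reshape `source` into a rank-1 tensor with a dynamic
// extent. `shape` must be a 1-D signless-integer or index tensor; to satisfy
// the verifier above, its static length (here 1) must equal the result rank.
static Value buildFlattenToDynamic1D(OpBuilder &builder, Location loc,
                                     Value source, Value shape) {
  auto elementType = source.getType().cast<TensorType>().getElementType();
  // Rank-1 result with one dynamic dimension and the source's element type.
  auto resultType =
      RankedTensorType::get({ShapedType::kDynamicSize}, elementType);
  // Invokes the OpBuilder added in the TableGen definition.
  return builder.create<tensor::ReshapeOp>(loc, resultType, source, shape);
}
```

For a `shape` operand of type `tensor<?xi32>`, the same call would have to produce an unranked result (`tensor<*xf32>`) instead, since the verifier rejects a dynamic-length shape operand paired with a ranked result type.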