diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2020,6 +2020,76 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+def MemRefReshapeOp: Std_Op<"memref_reshape", [
+    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+    NoSideEffect]> {
+  let summary = "memref reshape operation";
+  let description = [{
+    The `memref_reshape` operation converts a memref from one type to an
+    equivalent type with a provided shape. The data is never copied or
+    modified. The source and destination types are compatible if both have the
+    same element type, address space and identity layout map. The following
+    combinations are possible:
+
+    a. Source type is ranked or unranked. Shape argument has static length.
+    Result type is ranked.
+
+    ```mlir
+    // Reshape statically-shaped memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
+    %dst0 = memref_reshape %src(%shape0)
+             : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
+    // Flatten unranked memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+    ```
+
+    b. Source type is ranked or unranked. Shape argument has dynamic length.
+    Result type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
+    // Reshape unranked memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedOrUnrankedMemRef:$source,
+    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+  );
+  let results = (outs AnyRankedOrUnrankedMemRef:$result);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
+    "Value operand, Value shape", [{
+      result.addOperands(operand);
+      result.addOperands(shape);
+      result.addTypes(resultType);
+    }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns the result type if it is ranked. The result may also be an
+    /// unranked memref, in which case a null MemRefType is returned.
+    MemRefType getType() {
+      return getResult().getType().dyn_cast<MemRefType>();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` `(` type($source) `,` type($shape)
+    `)` `->` type($result)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2145,6 +2145,64 @@
   return impl::foldCastOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+Value MemRefReshapeOp::getViewSource() { return source(); }
+
+/// Returns the memory space of `type`, which must be a ranked or unranked
+/// memref type.
+static unsigned getMemorySpaceOfMemRef(Type type) {
+  if (auto rankedType = type.dyn_cast<MemRefType>())
+    return rankedType.getMemorySpace();
+  return type.cast<UnrankedMemRefType>().getMemorySpace();
+}
+
+/// Verifies that `source` can be reinterpreted as the result type using
+/// `shape`: element types and memory spaces must match, ranked source and
+/// result types must have identity layouts, and a ranked result must have as
+/// many dimensions as the statically-known length of `shape`.
+static LogicalResult verify(MemRefReshapeOp op) {
+  Type operandType = op.source().getType();
+  Type resultType = op.result().getType();
+
+  Type operandElementType = operandType.cast<ShapedType>().getElementType();
+  Type resultElementType = resultType.cast<ShapedType>().getElementType();
+  if (operandElementType != resultElementType)
+    return op.emitOpError("element types of source and destination memref "
+                          "types should be the same");
+
+  // The data is never copied, so the result must live in the same memory
+  // space as the source.
+  if (getMemorySpaceOfMemRef(operandType) !=
+      getMemorySpaceOfMemRef(resultType))
+    return op.emitOpError("source and destination memref types should have "
+                          "the same memory space");
+
+  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+    if (!operandMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "source memref type should have identity affine map");
+
+  // `shape` is a 1-D memref; its length is the number of result dimensions
+  // and may be dynamic.
+  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+  if (resultMemRefType) {
+    if (!resultMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "result memref type should have identity affine map");
+    if (shapeSize == ShapedType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "reshape to statically-ranked memref type");
+    if (shapeSize != resultMemRefType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's memref rank");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -102,3 +102,62 @@
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
   transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
+
+// -----
+
+// CHECK-LABEL: memref_reshape_element_type_mismatch
+func @memref_reshape_element_type_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{element types of source and destination memref types should be the same}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_ranked_shape_unranked
+func @memref_reshape_dst_ranked_shape_unranked(
+       %buf: memref<*xf32>, %shape: memref<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked memref type}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_shape_rank_mismatch
+func @memref_reshape_dst_shape_rank_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
+  memref_reshape %buf(%shape)
+    : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_affine_map_is_not_identity
+func @memref_reshape_affine_map_is_not_identity(
+        %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
+        %shape: memref<1xi32>) {
+  // expected-error @+1 {{source memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
+    -> memref<8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_result_affine_map_is_not_identity
+func @memref_reshape_result_affine_map_is_not_identity(
+        %buf: memref<4x4xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{result memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_memory_space_mismatch
+func @memref_reshape_memory_space_mismatch(
+       %buf: memref<*xf32, 1>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{source and destination memref types should have the same memory space}}
+  memref_reshape %buf(%shape) : (memref<*xf32, 1>, memref<1xi32>) -> memref<?xf32>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -54,3 +54,15 @@
   %result = atan2 %arg0, %arg1 : f32
   return %result : f32
 }
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
+         %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
+  %dyn_vec = memref_reshape %unranked(%shape1)
+               : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+  %dyn_mat = memref_reshape %dyn_vec(%shape2)
+               : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+  %new_unranked = memref_reshape %dyn_mat(%shape3)
+               : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  return %new_unranked : memref<*xf32>
+}