diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2071,6 +2071,80 @@
   let summary = "floating point division remainder operation";
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+def ReshapeMemRefCastOp: Std_Op<"reshape_memref_cast", [
+    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+    NoSideEffect]> {
+  let summary = "reshape memref cast operation";
+  let description = [{
+    The `reshape_memref_cast` operation converts a memref from one type to an
+    equivalent type with a provided shape. The data is never copied or moved.
+    The source and destination types are compatible if both have the same
+    element type, address space and identity layout map. The following
+    combinations are possible:
+
+    a. Both are ranked memref types.
+
+    ```mlir
+    // Reshape statically-shaped memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
+    %dst0 = reshape_memref_cast %src(%shape0)
+             : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
+    ```
+
+    b. Source type is ranked, destination type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
+    ```
+
+    c. Source type is unranked, destination type is ranked.
+
+    ```mlir
+    // Flatten unranked memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+    ```
+
+    d. Both are unranked memref types.
+
+    ```mlir
+    // Reshape unranked memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedOrUnrankedMemRef:$operand,
+    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+  );
+  let results = (outs AnyRankedOrUnrankedMemRef:$result);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
+    "Value operand, Value shape", [{
+       result.addOperands(operand);
+       result.addOperands(shape);
+       result.types.push_back(resultType);
+     }]>];
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+  }];
+
+  let assemblyFormat = [{
+    $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
+    `)` `->` type($result)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2020,6 +2020,44 @@
   return IntegerAttr();
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
+
+static LogicalResult verify(ReshapeMemRefCastOp op) {
+  Type operandType = op.operand().getType();
+  Type resultType = op.result().getType();
+
+  Type operandElementType = operandType.cast<ShapedType>().getElementType();
+  Type resultElementType = resultType.cast<ShapedType>().getElementType();
+  if (operandElementType != resultElementType)
+    return op.emitOpError("element types of source and destination memref "
+                          "types should be the same");
+
+  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+    if (!operandMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "operand memref type should have identity affine map");
+
+  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+  if (resultMemRefType) {
+    if (shapeSize == ShapedType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "cast statically-ranked memref type");
+    if (shapeSize != resultMemRefType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's memref rank");
+
+    if (!resultMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "result memref type should have identity affine map");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -15,3 +15,45 @@
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }
+
+// -----
+
+// CHECK-LABEL: reshape_memref_cast_element_type_mismatch
+func @reshape_memref_cast_element_type_mismatch(
+    %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{element types of source and destination memref types should be the same}}
+  reshape_memref_cast %buf(%shape)
+    : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: reshape_memref_cast_dst_ranked_shape_unranked
+func @reshape_memref_cast_dst_ranked_shape_unranked(
+    %buf: memref<*xf32>, %shape: memref<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
+  reshape_memref_cast %buf(%shape)
+    : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: reshape_memref_cast_dst_shape_rank_mismatch
+func @reshape_memref_cast_dst_shape_rank_mismatch(
+    %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
+  reshape_memref_cast %buf(%shape)
+    : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: reshape_memref_cast_affine_map_is_not_identity
+func @reshape_memref_cast_affine_map_is_not_identity(
+    %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
+    %shape: memref<1xi32>) {
+  // expected-error @+1 {{operand memref type should have identity affine map}}
+  reshape_memref_cast %buf(%shape)
+    : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
+    -> memref<8xf32>
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -703,6 +703,29 @@
   return
 }
 
+// CHECK-LABEL: func @reshape_memref_cast(
+func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
+                          %shape2: memref<2xi32>, %shape3: memref<?xi32>) {
+  // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
+  // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
+
+  // CHECK-NEXT: [[DYN_VEC:%.*]] = reshape_memref_cast [[UNRANKED]]
+  // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+  %dyn_vec = reshape_memref_cast %unranked(%shape1)
+               : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+
+  // CHECK-NEXT: [[DYN_MAT:%.*]] = reshape_memref_cast [[DYN_VEC]]
+  // CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+  %dyn_mat = reshape_memref_cast %dyn_vec(%shape2)
+               : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+
+  // CHECK-NEXT: {{%.*}} = reshape_memref_cast [[DYN_MAT]]
+  // CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  %new_unranked = reshape_memref_cast %dyn_mat(%shape3)
+                    : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  return
+}
+
 // CHECK-LABEL: func @memref_view(%arg0
 func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
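
For context, a minimal usage sketch of the new op (not part of the patch; the function name `@flatten` and the constant shape value are illustrative assumptions). It materializes the target shape in a rank-1 `i32` buffer with standard-dialect `alloc`/`constant`/`store` and then views a static 2D buffer as 1D, which the verifier above accepts because element types match, the layouts are identity, and the shape operand's static length equals the result rank.

```mlir
// Hypothetical example: flatten memref<4x4xf32> into memref<16xf32> by first
// writing the target shape into a rank-1 i32 buffer.
func @flatten(%src: memref<4x4xf32>) -> memref<16xf32> {
  %shape = alloc() : memref<1xi32>   // shape operand must be a 1-D memref
  %c0 = constant 0 : index
  %size = constant 16 : i32          // 4 * 4 elements in the result
  store %size, %shape[%c0] : memref<1xi32>
  // No data is copied; the result aliases %src with the new shape.
  %dst = reshape_memref_cast %src(%shape)
           : (memref<4x4xf32>, memref<1xi32>) -> memref<16xf32>
  return %dst : memref<16xf32>
}
```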