diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -111,6 +111,13 @@
     rank to obtain the memref with the smaller rank. In the case of a
    dimension expansion, the reassociation maps can be interpreted as inverse
    maps.
+    The result memref type of a reshape when dimensions are collapsed
+    (operand memref type when dimensions are expanded) can be
+    zero-ranked if the operand memref type (or the result memref type
+    when dimensions are expanded) is statically shaped with all
+    dimensions being unit extent. In such cases the reassociation map
+    is empty.
+
    Examples:

    ```mlir
@@ -152,6 +159,13 @@
    rank to obtain the tensor with the smaller rank. In the case of a
    dimension expansion, the reassociation maps can be interpreted as inverse
    maps.
+    The result tensor type of a reshape when dimensions are collapsed
+    (operand tensor type when dimensions are expanded) can be
+    zero-ranked if the operand tensor type (or the result tensor type
+    when dimensions are expanded) is statically shaped with all
+    dimensions being unit extent. In such cases the reassociation map
+    is empty.
+
    Examples:

    ```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -438,11 +438,21 @@
    std::swap(expandedRank, collapsedRank);
    std::swap(expandedType, collapsedType);
  }
-  if (expandedRank == 0 || collapsedRank == 0)
+  if (expandedRank == 0)
    return op.emitOpError("expected non-zero memref ranks");
  if (expandedRank == collapsedRank)
    return op.emitOpError("expected to collapse or expand dims");

+  if (collapsedRank == 0) {
+    // If the collapsed rank is 0, the expanded type must be statically shaped
+    // with all dimensions being unit extent.
+    if (llvm::any_of(expandedType.getShape(),
+                     [](int64_t dim) -> bool { return dim != 1; }))
+      return op.emitOpError(
+          "invalid to reshape tensor/memref with non-unit extent dimensions to "
+          "zero-rank tensor/memref");
+    return success();
+  }
  if (collapsedRank != op.reassociation().size())
    return op.emitOpError("expected rank of the collapsed type(")
           << collapsedRank << ") to be the number of reassociation maps("
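Note (illustration, not part of the patch): the new `collapsedRank == 0` branch only accepts expanded types whose shape is fully static with every dimension of unit extent; a dynamic or non-unit dimension makes `llvm::any_of` fire. As a hedged sketch, a hypothetical source such as `memref<1x2xf32>` would now be rejected with the error string introduced above:

```mlir
// Hypothetical invalid case: the trailing extent 2 is non-unit, so
// verification fails with "invalid to reshape tensor/memref with non-unit
// extent dimensions to zero-rank tensor/memref".
func @illegal_reshape_to_zero_rank(%arg0 : memref<1x2xf32>) {
  %0 = linalg.reshape %arg0 [] : memref<1x2xf32> into memref<f32>
  return
}
```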
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -273,3 +273,32 @@
 // CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+
+func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_zero_dim
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
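Note (illustration, not part of the patch): the CHECK lines above reflect the descriptor layouts involved. A zero-ranked memref lowers to a bare `{ float*, float*, i64 }` descriptor (allocated pointer, aligned pointer, offset), while the rank-2 descriptor additionally carries `[2 x i64]` size and stride arrays. The lowered reshape therefore copies the three shared fields and, on expansion, materializes the unit sizes and strides as constants, roughly as in this sketch (`%src` is an assumed name for the incoming zero-rank descriptor):

```mlir
// Sketch of the expanding reshape matched by the CHECK lines above.
%desc = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%ptr = llvm.extractvalue %src[0] : !llvm<"{ float*, float*, i64 }">
%0 = llvm.insertvalue %ptr, %desc[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// ... the aligned pointer [1] and the offset [2] are copied the same way ...
%one = llvm.mlir.constant(1 : index) : !llvm.i64
%1 = llvm.insertvalue %one, %0[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// ... size [3, 1] and strides [4, 0], [4, 1] likewise receive constant 1.
```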
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -630,3 +630,26 @@
 // CHECK: linalg.batch_matmul
 // CHECK: linalg.batch_matmul

+// -----
+
+func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
+{
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+  %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+  return %0, %1 : tensor<f32>, tensor<1x1xf32>
+}
+// CHECK-LABEL: func @tensor_reshape_zero_dim
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
+{
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @memref_reshape_zero_dim
+// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+// CHECK: linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>