diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2219,6 +2219,11 @@
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  // expand_shape must strictly increase the rank.
+  if (srcType.getRank() >= resultType.getRank())
+    return emitOpError("expected rank expansion, but found source rank ")
+           << srcType.getRank() << " >= result rank " << resultType.getRank();
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
@@ -2370,6 +2375,11 @@
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  // collapse_shape must strictly decrease the rank.
+  if (srcType.getRank() <= resultType.getRank())
+    return emitOpError("expected rank reduction, but found source rank ")
+           << srcType.getRank() << " <= result rank " << resultType.getRank();
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                   srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1398,10 +1398,24 @@
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+  // expand_shape must strictly increase the rank.
+  if (srcType.getRank() >= resultType.getRank())
+    return emitOpError("expected rank expansion, but found source rank ")
+           << srcType.getRank() << " >= result rank " << resultType.getRank();
+
+  return verifyTensorReshapeOp(*this, resultType, srcType);
 }
 
 LogicalResult CollapseShapeOp::verify() {
-  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+  // collapse_shape must strictly decrease the rank.
+  if (srcType.getRank() <= resultType.getRank())
+    return emitOpError("expected rank reduction, but found source rank ")
+           << srcType.getRank() << " <= result rank " << resultType.getRank();
+
+  return verifyTensorReshapeOp(*this, srcType, resultType);
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -392,9 +392,9 @@
 
 // -----
 
-func.func @expand_shape(%arg0: memref<f32>) {
-  // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}}
-  %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+func.func @expand_shape(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}}
+  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<?x?xf32> into memref<?x?x?xf32>
   return
 }
 
@@ -408,16 +408,30 @@
 
 // -----
 
-func.func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
-  // expected-error @+1 {{op reassociation index 0 is out of bounds}}
-  %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+  %0 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<?x?xf32> into memref<?xf32>
+}
+
+// -----
+
+func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+  %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
 }
 
 // -----
 
-func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
-  // expected-error @+1 {{op reassociation index 0 is out of bounds}}
-  %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+  %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -294,6 +294,20 @@
 
 // -----
 
+func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
 func.func @rank(%0: f32) {
   // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
   "tensor.rank"(%0): (f32)->index