diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -766,10 +766,14 @@
     concat([], [4,5,6]) -> [4,5,6]
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
-  let results = (outs Shape_ShapeType:$result);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
   let hasFolder = 1;
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -179,12 +179,24 @@
   // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
   %lhs = shape.const_shape [0, 1] : !shape.shape
   %rhs = shape.const_shape [2, 3] : !shape.shape
-  %0 = shape.concat %lhs, %rhs
+  %0 = shape.concat %lhs, %rhs : !shape.shape, !shape.shape -> !shape.shape
   return %0 : !shape.shape
 }
 
 // -----
 
+// Basic case with extent tensor operands.
+// CHECK-LABEL: func @f
+func.func @f() -> tensor<4xindex> {
+  // CHECK: shape.const_shape [0, 1, 2, 3] : tensor<4xindex>
+  %lhs = shape.const_shape [0, 1] : tensor<2xindex>
+  %rhs = shape.const_shape [2, 3] : tensor<2xindex>
+  %0 = shape.concat %lhs, %rhs : tensor<2xindex>, tensor<2xindex> -> tensor<4xindex>
+  return %0 : tensor<4xindex>
+}
+
+// -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func.func @f() -> tensor<2xindex> {