Index: mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
===================================================================
--- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -766,10 +766,13 @@
     concat([], [4,5,6]) -> [4,5,6]
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
-  let results = (outs Shape_ShapeType:$result);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 
   let hasFolder = 1;
 }
Index: mlir/test/Dialect/Shape/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Shape/canonicalize.mlir
+++ mlir/test/Dialect/Shape/canonicalize.mlir
@@ -179,7 +179,7 @@
   // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
   %lhs = shape.const_shape [0, 1] : !shape.shape
   %rhs = shape.const_shape [2, 3] : !shape.shape
-  %0 = shape.concat %lhs, %rhs
+  %0 = shape.concat %lhs, %rhs : !shape.shape, !shape.shape -> !shape.shape
   return %0 : !shape.shape
 }
 
@@ -187,6 +187,18 @@
 
 // Basic case.
 // CHECK-LABEL: func @f
+func.func @f() -> tensor<4xindex> {
+  // CHECK: shape.const_shape [0, 1, 2, 3] : tensor<4xindex>
+  %lhs = shape.const_shape [0, 1] : tensor<2xindex>
+  %rhs = shape.const_shape [2, 3] : tensor<2xindex>
+  %0 = shape.concat %lhs, %rhs : tensor<2xindex>, tensor<2xindex> -> tensor<4xindex>
+  return %0 : tensor<4xindex>
+}
+
+// -----
+
+// Basic case.
+// CHECK-LABEL: func @f
 func.func @f() -> tensor<2xindex> {
   // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape