diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -90,6 +90,8 @@
                    OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);
 
+  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
   let hasFolder = 1;
 }
 
@@ -472,6 +474,7 @@
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
   let results = (outs Shape_ShapeType:$result);
 
+  let assemblyFormat = "$lhs `,` $rhs attr-dict";
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -237,6 +237,16 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  // Broadcasting with a scalar (rank-0, i.e. empty, shape) yields the other
+  // operand unchanged. Check each side on its own so the fold also fires when
+  // the other operand is not a constant.
+  if (operands[1] &&
+      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return lhs();
+  if (operands[0] &&
+      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return rhs();
+
   if (!operands[0] || !operands[1])
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,19 +54,41 @@
   // CHECK: shape.const_shape [7, 2]
   %0 = shape.const_shape [1, 2]
   %1 = shape.const_shape [7, 1]
-  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
 
 // -----
 
+// Rhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %arg0
+  %0 = shape.const_shape []
+  %1 = shape.broadcast %arg0, %0
+  return %1 : !shape.shape
+}
+
+// -----
+
+// Lhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %arg0
+  %0 = shape.const_shape []
+  %1 = shape.broadcast %0, %arg0
+  return %1 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2]
   %1 = shape.const_shape [7]
-  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
 
@@ -78,7 +100,7 @@
   // CHECK: shape.const_shape [0, 1, 2, 3]
   %lhs = shape.const_shape [0, 1]
   %rhs = shape.const_shape [2, 3]
-  %0 = "shape.concat"(%lhs, %rhs) : (!shape.shape, !shape.shape) -> !shape.shape
+  %0 = shape.concat %lhs, %rhs
   return %0 : !shape.shape
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -29,10 +29,10 @@
   return
 }
 
-func @test_broadcastable_fixed() {
+func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92]
   %1 = shape.const_shape [4, 57, 92]
-  %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }