diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -152,6 +152,8 @@
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs,
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);
+
+  let hasFolder = 1;
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape",
@@ -225,6 +227,8 @@
 
   let arguments = (ins Shape_ShapeType:$input);
   let results = (outs IndexTensor:$result);
+
+  let hasFolder = 1;
 }
 
 def Shape_JoinOp : Shape_Op<"join", []> {
@@ -376,6 +380,8 @@
 
   let arguments = (ins Shape_ShapeType:$operand, I32:$index);
   let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+
+  let hasFolder = 1;
 }
 
 def Shape_ConcatOp : Shape_Op<"concat",
@@ -393,6 +399,8 @@
 
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
   let results = (outs Shape_ShapeType:$result);
+
+  let hasFolder = 1;
 }
 
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -87,6 +88,28 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a broadcast of two constant shapes to the constant broadcasted shape,
+/// e.g. [1, 2] and [7, 1] fold to [7, 2].
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  // If the shapes are not compatible, we can't fold it.
+  // TODO: Fold to an "error".
+  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+    return nullptr;
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstShapeOp
 //===----------------------------------------------------------------------===//
@@ -176,6 +199,31 @@
   return success();
 }
 
+/// Folds a split of a constant shape at a constant index into two constant
+/// shapes. A negative index counts from the back: splitting [2, 3, 4, 5] at -1
+/// yields [2, 3, 4] and [5]. Out-of-range indices are not folded.
+LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results) {
+  if (!operands[0] || !operands[1])
+    return failure();
+  auto shapeVec = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto shape = llvm::makeArrayRef(shapeVec);
+  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
+  // Verify that the split point is in the correct range.
+  // TODO: Constant fold to an "error".
+  int64_t rank = shape.size();
+  if (!(-rank <= splitPoint && splitPoint <= rank))
+    return failure();
+  // Normalize a negative split point; stay in signed arithmetic via `rank`.
+  if (splitPoint < 0)
+    splitPoint += rank;
+  Builder builder(getContext());
+  results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
+  results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -189,6 +237,38 @@
   return success();
 }
 
+/// Folds a concatenation of two constant shapes to the constant shape holding
+/// the extents of `lhs` followed by those of `rhs`.
+OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  resultShape.append(lhsShape.begin(), lhsShape.end());
+  resultShape.append(rhsShape.begin(), rhsShape.end());
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(resultShape);
+}
+
+//===----------------------------------------------------------------------===//
+// ToExtentTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a constant shape to a dense 1-D `index` tensor holding its extents.
+OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0])
+    return nullptr;
+  Builder builder(getContext());
+  auto shape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
+                                    builder.getIndexType());
+  return DenseIntElementsAttr::get(type, shape);
+}
+
 namespace mlir {
 namespace shape {
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail
 
 // -----
 // CHECK-LABEL: func @f
@@ -7,3 +7,81 @@
   %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
   return %0 : !shape.shape
 }
+
+// -----
+// split_at: basic case.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.const_shape [2, 3]
+  // CHECK: shape.const_shape [4, 5]
+  %c2 = constant 2 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+}
+
+// -----
+// split_at: negative split point.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.const_shape [2, 3, 4]
+  // CHECK: shape.const_shape [5]
+  %c-1 = constant -1 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+}
+
+// -----
+// split_at: out of range split point. No folding.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.split_at
+  %c5 = constant 5 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+}
+
+// -----
+// broadcast: basic case.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2]
+  %0 = shape.const_shape [1, 2]
+  %1 = shape.const_shape [7, 1]
+  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+// broadcast: incompatible shapes. No folding.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.broadcast
+  %0 = shape.const_shape [2]
+  %1 = shape.const_shape [7]
+  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+// concat: basic case.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.const_shape [0, 1, 2, 3]
+  %lhs = shape.const_shape [0, 1]
+  %rhs = shape.const_shape [2, 3]
+  %0 = "shape.concat"(%lhs, %rhs) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+// to_extent_tensor: basic case.
+// CHECK-LABEL: func @f
+func @f() -> tensor<2xindex> {
+  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  %cs = shape.const_shape [0, 1]
+  %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
+  return %0 : tensor<2xindex>
+}