diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -180,6 +180,34 @@
   return builder.getIndexTensorAttr(resultShape);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                           ValueRange operands, DictionaryAttr attributes,
+                           RegionRange regions,
+                           SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
+OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  resultShape.append(lhsShape.begin(), lhsShape.end());
+  resultShape.append(rhsShape.begin(), rhsShape.end());
+  Builder builder(getContext());
+  return builder.getIndexTensorAttr(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstShapeOp
 //===----------------------------------------------------------------------===//
@@ -341,34 +369,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// ConcatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
-                           ValueRange operands, DictionaryAttr attributes,
-                           RegionRange regions,
-                           SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto shapeType = ShapeType::get(context);
-  inferredReturnTypes.push_back(shapeType);
-  return success();
-}
-
-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
-  resultShape.append(lhsShape.begin(), lhsShape.end());
-  resultShape.append(rhsShape.begin(), rhsShape.end());
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(resultShape);
-}
-
 //===----------------------------------------------------------------------===//
 // ToExtentTensorOp
 //===----------------------------------------------------------------------===//