Relocate the ConcatOp section (inferReturnTypes and fold) of Shape.cpp so that
it directly precedes the ConstShapeOp section. The added (+) and removed (-)
definitions are textually identical; only their position in the file changes.

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -181,6 +181,34 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                           ValueRange operands, DictionaryAttr attributes,
+                           RegionRange regions,
+                           SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
+OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  resultShape.append(lhsShape.begin(), lhsShape.end());
+  resultShape.append(rhsShape.begin(), rhsShape.end());
+  Builder builder(getContext());
+  return builder.getIndexTensorAttr(resultShape);
+}
+
+//===----------------------------------------------------------------------===//
 // ConstShapeOp
 //===----------------------------------------------------------------------===//
 
@@ -342,34 +370,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ConcatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
-                           ValueRange operands, DictionaryAttr attributes,
-                           RegionRange regions,
-                           SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto shapeType = ShapeType::get(context);
-  inferredReturnTypes.push_back(shapeType);
-  return success();
-}
-
-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
-  resultShape.append(lhsShape.begin(), lhsShape.end());
-  resultShape.append(rhsShape.begin(), rhsShape.end());
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(resultShape);
-}
-
-//===----------------------------------------------------------------------===//
 // ToExtentTensorOp
 //===----------------------------------------------------------------------===//
 