diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,9 @@
 /// Alias type for extent tensors.
 RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
+// Given an input shape Value, try to obtain the shape's values.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
+
 /// The shape descriptor type represents rank and dimension sizes.
 class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
 public:
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,26 @@
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
+LogicalResult shape::getShapeVec(Value input,
+                                 SmallVectorImpl<int64_t> &shapeValues) {
+  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+    if (!type.hasRank())
+      return failure();
+    shapeValues = llvm::to_vector<6>(type.getShape());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
+    shapeValues = llvm::to_vector<6>(
+        inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
+    return success();
+  } else {
+    return failure();
+  }
+}
+
 static bool isErrorPropagationPossible(TypeRange operandTypes) {
   return llvm::any_of(operandTypes, [](Type ty) {
     return ty.isa<SizeType, ShapeType, ValueShapeType>();
@@ -605,24 +625,6 @@
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-// Given an input shape Value, try to obtain the shape's values.
-LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
-  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
-    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
-    if (!type.hasRank())
-      return failure();
-    shapeValues = llvm::to_vector<6>(type.getShape());
-    return success();
-  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
-    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
-    return success();
-  } else {
-    return failure();
-  }
-}
-} // namespace
-
 void CstrBroadcastableOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   // Canonicalization patterns have overlap with the considerations during
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1259,3 +1259,14 @@
   // CHECK: return %[[SHAPE]]
   return %1 : !shape.shape
 }
+
+// -----
+
+// CHECK-LABEL: @cstr_broadcastable_folding
+func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
+  // CHECK: const_witness true
+  %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
+  %1 = constant dense<[4]> : tensor<1xindex>
+  %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
+  "use"(%2) : (!shape.witness) -> ()
+}