diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -124,7 +124,8 @@
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "int64_t":$index)>
+    OpBuilder<(ins "Value":$source, "int64_t":$index)>,
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$index)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -380,6 +380,20 @@
   build(builder, result, source, indexValue);
 }
 
+void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
+                  OpFoldResult index) {
+  if (index.is<Value>()) {
+    // Dynamic index: forward the SSA value to the `Value` builder unchanged.
+    build(builder, result, source, index.get<Value>());
+    return;
+  }
+  // Static index: extract the integer (same helper as getConstantIndex) and
+  // delegate to the `int64_t` builder, which turns it into an index Value.
+  std::optional<int64_t> staticIndex = getConstantIntValue(index);
+  assert(staticIndex && "expected a static index to be an IntegerAttr");
+  build(builder, result, source, *staticIndex);
+}
+
 std::optional<int64_t> DimOp::getConstantIndex() {
   return getConstantIntValue(getIndex());
 }