diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -97,6 +97,11 @@ SmallVectorImpl &staticVec, int64_t sentinel) { if (auto v = ofr.dyn_cast()) { + if (auto cst = v.getDefiningOp()) { + dispatchIndexOpFoldResult(cst.getValue(), dynamicVec, staticVec, + sentinel); + return; + } dynamicVec.push_back(v); staticVec.push_back(sentinel); return; @@ -678,14 +683,12 @@ SmallVector dynamicSizes; SmallVector staticSizes; for (unsigned i = 0; i < rank; ++i) { - // staticLow and staticHigh have full information of the padding config. - // This will grow staticLow and staticHigh with 1 value. If the config is - // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 - // value as well. + // This will grow staticSizes with 1 value. If sizes[i] is dynamic (ie not a + // constant), dynamicSizes will grow with 1 value as well. dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, ShapedType::kDynamicSize); } - auto resultType = RankedTensorType ::get(staticSizes, elementType); + auto resultType = inferResultType(staticSizes, elementType); build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); result.addAttributes(attrs); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -43,6 +43,11 @@ SmallVectorImpl &staticVec, int64_t sentinel) { if (auto v = ofr.dyn_cast()) { + if (auto cst = v.getDefiningOp()) { + dispatchIndexOpFoldResult(cst.getValue(), dynamicVec, staticVec, + sentinel); + return; + } dynamicVec.push_back(v); staticVec.push_back(sentinel); return;