diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -122,22 +122,26 @@ // IMPL-LABEL: ArrayAttr Test5Op::indexing_maps() { // IMPL: auto cst0 = getAffineConstantExpr(strides().getValue({ 0 }), context); // IMPL: auto cst1 = getAffineConstantExpr(strides().getValue({ 1 }), context); +// IMPL: auto cst2 = getAffineConstantExpr(strides().getValue({ 0 }), context); +// IMPL: auto cst3 = getAffineConstantExpr(strides().getValue({ 1 }), context); // IMPL: auto map0 = AffineMap::get(7, 9, {d0, d1 * s7 + d4, d2 * s8 + d5, d6}, context); -// IMPL: map0 = map0.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map0 = map0.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst2, cst3 }, 7, 0); // IMPL: map0 = simplifyAffineMap(map0); // IMPL: auto map1 = AffineMap::get(7, 9, {d3, d4, d5, d6}, context); -// IMPL: map1 = map1.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map1 = map1.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst2, cst3 }, 7, 0); // IMPL: map1 = simplifyAffineMap(map1); // IMPL: auto map2 = AffineMap::get(7, 7, {d0, d1, d2, d3}, context); -// IMPL: map2 = map2.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map2 = map2.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst2, cst3 }, 7, 0); // IMPL: map2 = simplifyAffineMap(map2); // IMPL: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // ods_def: def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) attr(strides: 2xi32) { - O(n, h, w, f) = std_addf(std_mulf( - I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c))); + O(n, h, w, f) = std_addf( + std_mulf(std_addf(I(n, h * strides[0] + kh, w * strides[1] + kw, c), + I(n, h * strides[0] + kh, w * strides[1] + kw, c)), + K(f, kh, kw, c))); } // Test documentation diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1190,6 +1190,9 @@ /// Attributes are per TC def. std::map registeredAttrs; + /// A map from AttrUse to AffineExpr symbol. + llvm::StringMap registeredAttrUseToSymbol; + StringRef docString; Parser &parser; @@ -1296,12 +1299,14 @@ if (failed(parseAttrUse(result))) return llvm::None; - // We create a new symbol for each attribute usage without reuse. This is - // fine given these symbols will be replaced with constants and folded away - // for concrete op instances. - result.symbol = getAffineSymbolExpr(symbols.size(), parser.context); - // Merely for taking the index. We don't reuse anyway. - symbols.emplace_back("", result.symbol); + auto symbolIt = registeredAttrUseToSymbol.find(result.getKey()); + if (symbolIt == registeredAttrUseToSymbol.end()) { + result.symbol = getAffineSymbolExpr(symbols.size(), parser.context); + symbols.emplace_back("", result.symbol); + registeredAttrUseToSymbol[result.getKey()] = result.symbol; + } else { + result.symbol = symbolIt->second; + } attrUses.push_back(result);