diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -813,6 +813,29 @@ return $_op.iterator_types(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return true if the indexing map depends on the current op instance. + This means that the indexing map is dynamically synthesized by using the + op instance's concrete attributes, instead of being static for all + instances of the same op kind. + }], + /*retTy=*/"bool", + /*methodName=*/"hasDynamicIndexingMaps", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return false; }] + >, + InterfaceMethod< + /*desc=*/[{ + Verify all attributes used by indexing maps are valid. + }], + /*retTy=*/"LogicalResult", + /*methodName=*/"verifyIndexingMapRequiredAttributes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return success(); }] + >, InterfaceMethod< /*desc=*/[{ Return the indexing maps attribute within the current operation. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -302,6 +302,12 @@ if (op->getNumResults() > linalgOp.getNumOutputTensors()) return op->emitError("unexpected #results > #outputs"); + // Before checking indexing maps, we need to make sure the attributes + // referenced by them are valid. + if (linalgOp.hasDynamicIndexingMaps()) + if (failed(linalgOp.verifyIndexingMapRequiredAttributes())) + return failure(); + // All shaped operands must be indexed. 
if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands()) return linalgOp.emitOpError("expected the number of indexing_map (") diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -92,6 +92,18 @@ // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, // ODS: OptionalAttr:$optional_attr // +// ODS: bool hasDynamicIndexingMaps(); +// ODS: LogicalResult verifyIndexingMapRequiredAttributes(); +// +// IMPL: bool Test4Op::hasDynamicIndexingMaps() { return true; } +// IMPL: LogicalResult Test4Op::verifyIndexingMapRequiredAttributes() +// IMPL: op->getAttrOfType<ArrayAttr>("array_attr") +// IMPL: op->getAttr("f32_attr") +// IMPL: op->getAttrOfType<DenseElementsAttr>("fvec_attr") +// IMPL: op->getAttr("i32_attr") +// IMPL: op->getAttr("i64_attr") +// IMPL: op->getAttrOfType<DenseElementsAttr>("ivec_attr") +// ods_def<Test4Op> : def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) attr( diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1126,6 +1126,11 @@ void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); + /// Print methods related to indexing map required attributes. + void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os, + StringRef cppOpName, + ComprehensionParsingState &state); + /// Print the C++ StructuredOpsInterface impl of `indexing_maps`. 
void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); @@ -1770,6 +1775,7 @@ std::string extraMethods; llvm::raw_string_ostream ss(extraMethods); printReferenceIterators(ss, cppOpName, state); + printIndexingMapRequiredAttrMethods(ss, cppOpName, state); printReferenceIndexingMaps(ss, cppOpName, state); printRegionBuilder(ss, cppOpName, state); printCanonicalizersAndFolders(ss, cppOpName); @@ -1906,6 +1912,8 @@ std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } + + {7} }]; })FMT"; @@ -1971,9 +1979,18 @@ llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList); } + std::string attrMethods; + if (!registeredAttrs.empty()) { + attrMethods = R"( + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + )"; + } + // Finally put everything together. os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc, - attrList, state.orderedTensorArgs.size(), attrBuilder); + attrList, state.orderedTensorArgs.size(), attrBuilder, + attrMethods); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`. 
@@ -2032,6 +2049,115 @@
   os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
 }
 
+/// Print the C++ impl of `hasDynamicIndexingMaps` and
+/// `verifyIndexingMapRequiredAttributes` for the op named `cppOpName`.
+///
+/// Nothing is printed when no attribute is registered. Otherwise the generated
+/// verifier checks every non-optional registered attribute: it must be present
+/// and have the registered element type, and vector attributes must also have
+/// the registered shape. `state` is unused.
+void TCParser::printIndexingMapRequiredAttrMethods(
+    llvm::raw_ostream &os, StringRef cppOpName,
+    ComprehensionParsingState &state) {
+  if (registeredAttrs.empty())
+    return;
+
+  // One generated check per attribute, joined into the verifier body below.
+  // In every format string, `{{` is formatv's escape for a literal `{`.
+  SmallVector<std::string, 4> attributes;
+  for (const auto &attr : registeredAttrs) {
+    // No check is generated for optional attributes.
+    if (attr.second.isOptional)
+      continue;
+
+    llvm::StringRef name = attr.first;
+    llvm::StringRef elementType = attr.second.elementType;
+    const auto &dims = attr.second.vectorDims;
+
+    // Predicate the generated code calls on the attribute's (element) type.
+    std::string elemTypeCheck = llvm::StringSwitch<std::string>(elementType)
+                                    .Case("f32", "isF32()")
+                                    .Case("i32", "isInteger(32)")
+                                    .Case("i64", "isInteger(64)")
+                                    .Default("");
+    if (elemTypeCheck.empty()) {
+      // Report the unsupported type and bail out without printing anything.
+      (void)parser.emitError(
+          "unimplemented support for attribute element type: " + elementType);
+      return;
+    }
+
+    if (dims.empty() && !attr.second.isArray) {
+      // Scalar case.
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttr("{0}")) {{
+          if (!attr.getType().{1}) return op->emitError(
+            "incorrect type for indexing map required attribute '{0}'");
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+        )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
+      continue;
+    }
+
+    if (!dims.empty()) {
+      // Vector case.
+      SmallVector<std::string, 4> dimStrs;
+      for (uint64_t dim : dims)
+        dimStrs.push_back(std::to_string(dim));
+
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
+          if (!attr.getType().getElementType().{1}) return op->emitError(
+            "incorrect element type for indexing map required attribute '{0}'");
+          if (attr.getType().getShape() != ArrayRef<int64_t>{{ {2} })
+            return op->emitError(
+              "incorrect shape for indexing map required attribute '{0}'");
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+        )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck,
+                                         llvm::join(dimStrs, ", ")));
+      continue;
+    }
+
+    {
+      // Array case.
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttrOfType<ArrayAttr>("{0}")) {{
+          for (Attribute element : attr) {{
+            if (!element.getType().{1}) return op->emitError(
+              "incorrect element type for indexing map required attribute '{0}'");
+          }
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+        )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
+    }
+  }
+
+  const char *methodFmt = R"FMT(
+  bool {0}::hasDynamicIndexingMaps() {{ return true; }
+
+  LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
+    Operation *op = getOperation();
+    {1}
+    return success();
+  }
+  )FMT";
+
+  os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n"));
+}
+
 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
                                           StringRef cppOpName,