diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
--- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {