diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -566,20 +566,17 @@ Type(elementTypeCall), etype.predicate>]>, - descr # " of " # etype.summary # " values", cppClassName> { - // The type of elements in the container. - Type elementType = etype; - - // Call to retrieve. - code getElementTypeCall = elementTypeCall; -} + descr # " of " # etype.summary # " values", cppClassName>; class ShapedContainerType allowedTypes, Pred containerPred, string descr, string cppClassName = "::mlir::Type"> : - ContainerType, containerPred, - "$_self.cast<::mlir::ShapedType>().getElementType()", descr, - cppClassName>; + Type.predicate>, + "; }($_self.cast<::mlir::ShapedType>().getElementType())">]>, + descr # " of " # AnyTypeOf.summary # " values", cppClassName>; // Whether a shaped type is ranked. def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -25,11 +25,11 @@ // CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) { +// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) { // CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) { +// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) { // CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; // CHECK-LABEL: OpA::verify