diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1464,7 +1464,7 @@ VectorType getPassThruVectorType() { return pass_thru().getType().cast(); } - VectorType getResultVectorType() { + VectorType getVectorType() { return result().getType().cast(); } }]; @@ -1478,7 +1478,7 @@ Arguments<(ins Arg:$base, VectorOfRankAndType<[1], [AnyInteger]>:$indices, VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$value)> { + VectorOfRank<[1]>:$valueToStore)> { let summary = "scatters elements from a vector into memory as defined by an index vector and mask"; @@ -1520,12 +1520,13 @@ VectorType getMaskVectorType() { return mask().getType().cast(); } - VectorType getValueVectorType() { - return value().getType().cast(); + VectorType getVectorType() { + return valueToStore().getType().cast(); } }]; - let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` " - "type($base) `,` type($indices) `,` type($mask) `,` type($value)"; + let assemblyFormat = + "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` " + "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; } @@ -1575,7 +1576,7 @@ VectorType getPassThruVectorType() { return pass_thru().getType().cast(); } - VectorType getResultVectorType() { + VectorType getVectorType() { return result().getType().cast(); } }]; @@ -1589,7 +1590,7 @@ Arguments<(ins Arg:$base, Variadic:$indices, VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$value)> { + VectorOfRank<[1]>:$valueToStore)> { let summary = "writes elements selectively from a vector as defined by a mask"; @@ -1626,12 +1627,13 @@ VectorType getMaskVectorType() { return mask().getType().cast(); } - VectorType getValueVectorType() { - return value().getType().cast(); + VectorType getVectorType() { + return valueToStore().getType().cast(); } }]; - let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` " - "type($base) `,` type($mask) `,` type($value)"; + let assemblyFormat = + "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` " + "type($base) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -437,7 +437,7 @@ return failure(); // Get index ptrs. - VectorType vType = gather.getResultVectorType(); + VectorType vType = gather.getVectorType(); Type iType = gather.getIndicesVectorType().getElementType(); Value ptrs; if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), @@ -471,7 +471,7 @@ return failure(); // Get index ptrs. - VectorType vType = scatter.getValueVectorType(); + VectorType vType = scatter.getVectorType(); Type iType = scatter.getIndicesVectorType().getElementType(); Value ptrs; if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), @@ -480,7 +480,7 @@ // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp( - scatter, adaptor.value(), ptrs, adaptor.mask(), + scatter, adaptor.valueToStore(), ptrs, adaptor.mask(), rewriter.getI32IntegerAttr(align)); return success(); } @@ -500,7 +500,7 @@ MemRefType memRefType = expand.getMemRefType(); // Resolve address. - auto vtype = typeConverter->convertType(expand.getResultVectorType()); + auto vtype = typeConverter->convertType(expand.getVectorType()); Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), adaptor.indices(), rewriter); @@ -528,7 +528,7 @@ adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( - compress, adaptor.value(), ptr, adaptor.mask()); + compress, adaptor.valueToStore(), ptr, adaptor.mask()); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2458,7 +2458,7 @@ static LogicalResult verify(GatherOp op) { VectorType indicesVType = op.getIndicesVectorType(); VectorType maskVType = op.getMaskVectorType(); - VectorType resVType = op.getResultVectorType(); + VectorType resVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (resVType.getElementType() != memType.getElementType()) @@ -2504,15 +2504,15 @@ static LogicalResult verify(ScatterOp op) { VectorType indicesVType = op.getIndicesVectorType(); VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getValueVectorType(); + VectorType valueVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and value element type should match"); + return op.emitOpError("base and valueToStore element type should match"); if (valueVType.getDimSize(0) != indicesVType.getDimSize(0)) - return op.emitOpError("expected value dim to match indices dim"); + return op.emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected value dim to match mask dim"); + return op.emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -2548,7 +2548,7 @@ static LogicalResult verify(ExpandLoadOp op) { VectorType maskVType = op.getMaskVectorType(); VectorType passVType = op.getPassThruVectorType(); - VectorType resVType = op.getResultVectorType(); + VectorType resVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (resVType.getElementType() != memType.getElementType()) @@ -2595,15 +2595,15 @@ static LogicalResult verify(CompressStoreOp op) { VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getValueVectorType(); + VectorType valueVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and value element type should match"); + return op.emitOpError("base and valueToStore element type should match"); if (llvm::size(op.indices()) != memType.getRank()) return op.emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected value dim to match mask dim"); + return op.emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -2616,8 +2616,8 @@ switch (get1DMaskFormat(compress.mask())) { case MaskFormat::AllTrue: rewriter.replaceOpWithNewOp( - compress, compress.value(), compress.base(), compress.indices(), - false); + compress, compress.valueToStore(), compress.base(), + compress.indices(), false); return success(); case MaskFormat::AllFalse: rewriter.eraseOp(compress); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1300,7 +1300,7 @@ func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { - // expected-error@+1 {{'vector.scatter' op base and value element type should match}} + // expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}} vector.scatter %base[%indices], %mask, %value : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> } @@ -1318,7 +1318,7 @@ func @scatter_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { - // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}} + // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}} vector.scatter %base[%indices], %mask, %value : memref, vector<17xi32>, vector<16xi1>, vector<16xf32> } @@ -1327,7 +1327,7 @@ func @scatter_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) { - // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}} + // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}} vector.scatter %base[%indices], %mask, %value : memref, vector<16xi32>, vector<17xi1>, vector<16xf32> } @@ -1368,7 +1368,7 @@ func @compress_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = constant 0 : index - // expected-error@+1 {{'vector.compressstore' op base and value element type should match}} + // expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}} vector.compressstore %base[%c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> } @@ -1376,7 +1376,7 @@ func @compress_dim_mask_mismatch(%base: memref, %mask: vector<17xi1>, %value: vector<16xf32>) { %c0 = constant 0 : index - // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}} + // expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}} vector.compressstore %base[%c0], %mask, %value : memref, vector<17xi1>, vector<16xf32> }