Reading the code - MinHashDriver and related classes
Concepts used: generics (Class<? extends ...>), Hadoop counters, hash function implementations
package org.apache.mahout.clustering.minhash;
public final class MinHashDriver extends AbstractJob
Input: vectors in SequenceFile format.
Output: depends on the debug flag. With debugOutput on, values are the full vectors (VectorWritable) written as SequenceFiles; otherwise they are item ids (Text) written as plain text files. The corresponding wiring in run():
Class<? extends Writable> outputClass =
    debugOutput ? VectorWritable.class : Text.class;
Class<? extends OutputFormat> outputFormatClass =
    debugOutput ? SequenceFileOutputFormat.class : TextOutputFormat.class;

job.setMapperClass(MinHashMapper.class);
job.setReducerClass(MinHashReducer.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(outputFormatClass);
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(outputClass);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(outputClass);
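For orientation, the surrounding run() method follows the usual AbstractJob driver pattern. A minimal sketch, assuming Mahout's AbstractJob helpers (getInputPath()/getOutputPath()) and omitting option parsing; the real MinHashDriver also copies its parsed options into the Configuration so the mapper and reducer can read them:

// Sketch only: generic Hadoop driver wiring around the settings above.
@Override
public int run(String[] args) throws Exception {
  // addOption(...) / parseArguments(args) calls omitted.
  Job job = new Job(getConf(), "MinHash Clustering");
  job.setJarByClass(MinHashDriver.class);

  FileInputFormat.setInputPaths(job, getInputPath());
  FileOutputFormat.setOutputPath(job, getOutputPath());

  // ... the mapper/reducer/format settings quoted above ...

  return job.waitForCompletion(true) ? 0 : -1;
}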
The mapper obtains its hash functions from HashFactory (called from MinHashMapper.setup()):

hashFunction = HashFactory.createHashFunctions(hashType, numHashFunctions);
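createHashFunctions itself is not quoted below; here is a minimal sketch consistent with the HashType enum and the hash classes that follow. The seeding detail (a fixed seed fed to RandomUtils.getRandom) is an assumption; the real HashFactory may differ:

// Sketch of HashFactory.createHashFunctions.
// Uses java.util.Random and org.apache.mahout.common.RandomUtils.
public static HashFunction[] createHashFunctions(HashType type, int numFunctions) {
  HashFunction[] hashFunction = new HashFunction[numFunctions];
  Random seed = RandomUtils.getRandom(11); // seed value is illustrative
  switch (type) {
    case LINEAR:
      for (int i = 0; i < numFunctions; i++) {
        hashFunction[i] = new LinearHash(seed.nextInt(), seed.nextInt());
      }
      break;
    case POLYNOMIAL:
      for (int i = 0; i < numFunctions; i++) {
        hashFunction[i] = new PolynomialHash(seed.nextInt(), seed.nextInt(), seed.nextInt());
      }
      break;
    case MURMUR:
      for (int i = 0; i < numFunctions; i++) {
        hashFunction[i] = new MurmurHashWrapper(seed.nextInt());
      }
      break;
    default:
      throw new IllegalStateException("Unknown hash type: " + type);
  }
  return hashFunction;
}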
MinHashMapper.map() first computes the min-hash signature: for every hash function, each feature value is serialized to four big-endian bytes, hashed, and the minimum hash value seen across all features is kept.

// minHashValues[i] is initialized to Integer.MAX_VALUE earlier in map().
for (int i = 0; i < numHashFunctions; i++) {
  for (Vector.Element ele : featureVector) {
    int value = (int) ele.get();
    // Serialize the feature value into 4 big-endian bytes before hashing.
    bytesToHash[0] = (byte) (value >> 24);
    bytesToHash[1] = (byte) (value >> 16);
    bytesToHash[2] = (byte) (value >> 8);
    bytesToHash[3] = (byte) value;
    int hashIndex = hashFunction[i].hash(bytesToHash);
    if (minHashValues[i] > hashIndex) {
      minHashValues[i] = hashIndex;
    }
  }
}

It then emits one record per hash function, keyed by a cluster id built from keyGroups consecutive min-hash values joined with '-'; requiring several values to match at once makes cluster membership stricter. In debug mode the emitted value is the full vector, otherwise just the item id.

for (int i = 0; i < numHashFunctions; i++) {
  StringBuilder clusterIdBuilder = new StringBuilder();
  for (int j = 0; j < keyGroups; j++) {
    clusterIdBuilder.append(minHashValues[(i + j) % numHashFunctions]).append('-');
  }
  // Drop the trailing '-'.
  String clusterId = clusterIdBuilder.toString();
  clusterId = clusterId.substring(0, clusterId.lastIndexOf('-'));
  Text cluster = new Text(clusterId);
  Writable point;
  if (debugOutput) {
    point = new VectorWritable(featureVector.clone());
  } else {
    point = new Text(item.toString());
  }
  context.write(cluster, point);
}

MinHashReducer.reduce() materializes the points for one cluster key. The copies (clone() / new Text) are needed because Hadoop reuses the same Writable instance while iterating over the values. Clusters smaller than minClusterSize are dropped, and enum counters record how many clusters were accepted or discarded.

Collection<Writable> pointList = new ArrayList<Writable>();
for (Writable point : points) {
  if (debugOutput) {
    Vector pointVector = ((VectorWritable) point).get().clone();
    Writable writablePointVector = new VectorWritable(pointVector);
    pointList.add(writablePointVector);
  } else {
    Writable pointText = new Text(point.toString());
    pointList.add(pointText);
  }
}
if (pointList.size() >= minClusterSize) {
  context.getCounter(Clusters.ACCEPTED).increment(1);
  for (Writable point : pointList) {
    context.write(cluster, point);
  }
} else {
  context.getCounter(Clusters.DISCARDED).increment(1);
}

HashFactory supplies three hash families, selected via the HashType enum. Each implements HashFunction (a single method, int hash(byte[] bytes)) and maps its result into [0, RandomUtils.MAX_INT_SMALLER_TWIN_PRIME).

public enum HashType {
  LINEAR, POLYNOMIAL, MURMUR
}

// Seeded multiply-add mixer over the input bytes.
static class LinearHash implements HashFunction {
  private final int seedA;
  private final int seedB;

  LinearHash(int seedA, int seedB) {
    this.seedA = seedA;
    this.seedB = seedB;
  }

  @Override
  public int hash(byte[] bytes) {
    long hashValue = 31;
    for (long byteVal : bytes) {
      hashValue *= seedA * byteVal;
      hashValue += seedB;
    }
    return Math.abs((int) (hashValue % RandomUtils.MAX_INT_SMALLER_TWIN_PRIME));
  }
}

// Variant that also shifts each byte and adds a third seed.
static class PolynomialHash implements HashFunction {
  private final int seedA;
  private final int seedB;
  private final int seedC;

  PolynomialHash(int seedA, int seedB, int seedC) {
    this.seedA = seedA;
    this.seedB = seedB;
    this.seedC = seedC;
  }

  @Override
  public int hash(byte[] bytes) {
    long hashValue = 31;
    for (long byteVal : bytes) {
      hashValue *= seedA * (byteVal >> 4);
      hashValue += seedB * byteVal + seedC;
    }
    return Math.abs((int) (hashValue % RandomUtils.MAX_INT_SMALLER_TWIN_PRIME));
  }
}

// Thin wrapper around Mahout's MurmurHash.
static class MurmurHashWrapper implements HashFunction {
  private final int seed;

  MurmurHashWrapper(int seed) {
    this.seed = seed;
  }

  @Override
  public int hash(byte[] bytes) {
    long hashValue = MurmurHash.hash64A(bytes, seed);
    return Math.abs((int) (hashValue % RandomUtils.MAX_INT_SMALLER_TWIN_PRIME));
  }
}
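To see why taking minima under several hash functions groups similar items: for a random hash function, the probability that two sets share the same min-hash value equals their Jaccard similarity. A small standalone demo (not Mahout code; the linear hash mixing here is an arbitrary choice for illustration) estimates Jaccard similarity from signatures:

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class MinHashDemo {
  public static void main(String[] args) {
    Set<Integer> a = new HashSet<Integer>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8));
    Set<Integer> b = new HashSet<Integer>(Arrays.asList(1, 2, 3, 4, 5, 10, 11, 12));

    int numHashes = 200;
    Random rnd = new Random(42);
    int matches = 0;
    for (int i = 0; i < numHashes; i++) {
      int seedA = rnd.nextInt();
      int seedB = rnd.nextInt();
      if (minHash(a, seedA, seedB) == minHash(b, seedA, seedB)) {
        matches++;
      }
    }
    // The match rate approximates Jaccard(a, b); here the true value is 5/11 ~ 0.45.
    System.out.println("estimated Jaccard: " + (double) matches / numHashes);
  }

  // One seeded hash per (seedA, seedB) pair; the minimum over the set is its min-hash.
  static int minHash(Set<Integer> set, int seedA, int seedB) {
    int min = Integer.MAX_VALUE;
    for (int x : set) {
      int h = (seedA * x + seedB) & Integer.MAX_VALUE; // mask the sign bit
      if (h < min) {
        min = h;
      }
    }
    return min;
  }
}

With 200 hash functions the estimate is typically within a few points of 0.45, which is exactly the effect MinHashMapper exploits: similar vectors collide on the same cluster ids.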